diff --git a/dependencies/requirements/base_requirements/requirements.txt b/dependencies/requirements/base_requirements/requirements.txt
index bc3e68345..0f39e0507 100644
--- a/dependencies/requirements/base_requirements/requirements.txt
+++ b/dependencies/requirements/base_requirements/requirements.txt
@@ -8,6 +8,7 @@ flax
 gcsfs
 google-api-python-client
 google-cloud-aiplatform
+google-cloud-mldiagnostics
 google-cloud-monitoring
 grain[parquet]
 huggingface_hub
diff --git a/dependencies/requirements/generated_requirements/cuda12-requirements.txt b/dependencies/requirements/generated_requirements/cuda12-requirements.txt
index cb04b19ad..f8b6ad4e9 100644
--- a/dependencies/requirements/generated_requirements/cuda12-requirements.txt
+++ b/dependencies/requirements/generated_requirements/cuda12-requirements.txt
@@ -4,7 +4,7 @@
 absl-py>=2.3.1
 aiofiles>=25.1.0
 aiohappyeyeballs>=2.6.1
-aiohttp>=3.13.1
+aiohttp>=3.13.2
 aiosignal>=1.4.0
 annotated-doc>=0.0.3
 annotated-types>=0.7.0
@@ -23,7 +23,7 @@ cachetools>=6.2.1
 certifi>=2025.10.5
 cfgv>=3.4.0
 charset-normalizer>=3.4.4
-cheroot>=11.0.0
+cheroot>=11.1.1
 chex>=0.1.91
 click>=8.3.0
 cloud-accelerator-diagnostics>=0.1.1
@@ -34,7 +34,7 @@ colorama>=0.4.6
 contourpy>=1.3.3
 coverage>=7.11.0
 cycler>=0.12.1
-datasets>=4.3.0
+datasets>=4.4.0
 decorator>=5.2.1
 dill>=0.4.0
 distlib>=0.4.0
@@ -46,7 +46,7 @@ einshape>=1.0
 etils>=1.13.0
 evaluate>=0.4.6
 execnet>=2.1.1
-fastapi>=0.120.1
+fastapi>=0.121.0
 filelock>=3.20.0
 flatbuffers>=25.9.23
 flax>=0.12.0
@@ -55,28 +55,28 @@ frozenlist>=1.8.0
 fsspec>=2025.9.0
 gast>=0.6.0
 gcsfs>=2025.9.0
-google-api-core>=2.28.0
-google-api-python-client>=2.185.0
-google-auth-httplib2>=0.2.0
+google-api-core>=2.28.1
+google-api-python-client>=2.186.0
+google-auth-httplib2>=0.2.1
 google-auth-oauthlib>=1.2.2
-google-auth>=2.41.1
+google-auth>=2.42.1
 google-benchmark>=1.9.4
-google-cloud-aiplatform>=1.122.0
+google-cloud-aiplatform>=1.124.0
 google-cloud-appengine-logging>=1.7.0
 google-cloud-audit-log>=0.4.0
 google-cloud-bigquery>=3.38.0
-google-cloud-core>=2.4.3
+google-cloud-core>=2.5.0
 google-cloud-logging>=3.12.1
+google-cloud-mldiagnostics>=0.5.0
 google-cloud-monitoring>=2.28.0
 google-cloud-resource-manager>=1.15.0
-google-cloud-storage>=2.19.0
+google-cloud-storage>=3.4.1
 google-crc32c>=1.7.1
-google-genai>=1.46.0
-google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip
+google-genai>=1.48.0
 google-pasta>=0.2.0
 google-resumable-media>=2.7.2
 googleapis-common-protos>=1.71.0
-grain>=0.2.13
+grain>=0.2.14
 grpc-google-iam-v1>=0.14.3
 grpcio-status>=1.71.2
 grpcio>=1.75.1
@@ -111,7 +111,7 @@ jsonlines>=4.0.0
 keras>=3.11.3
 kiwisolver>=1.4.9
 libclang>=18.1.1
-libcst>=1.8.5
+libcst>=1.8.6
 lxml>=6.0.2
 markdown-it-py>=4.0.0
 markdown>=3.9
@@ -122,13 +122,12 @@ mdurl>=0.1.2
 ml-collections>=1.1.0
 ml-dtypes>=0.5.3
 ml-goodput-measurement>=0.0.15
-mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip
 more-itertools>=10.8.0
 mpmath>=1.3.0
 msgpack>=1.1.2
 msgspec>=0.19.0
 multidict>=6.7.0
-multiprocess>=0.70.16
+multiprocess>=0.70.18
 mypy-extensions>=1.1.0
 namex>=0.1.0
 nest-asyncio>=1.6.0
@@ -195,8 +194,8 @@ python-dateutil>=2.9.0.post0
 pytype>=2024.10.11
 pytz>=2025.2
 pyyaml>=6.0.3
-qwix>=0.1.1
-regex>=2025.10.23
+qwix>=0.1.2
+regex>=2025.11.3
 requests-oauthlib>=2.0.0
 requests>=2.32.5
 rich>=14.2.0
@@ -214,7 +213,7 @@ simplejson>=3.20.2
 six>=1.17.0
 sniffio>=1.3.1
 sortedcontainers>=2.4.0
-starlette>=0.48.0
+starlette>=0.49.3
 sympy>=1.14.0
 tabulate>=0.9.0
 tenacity>=9.1.2
@@ -248,14 +247,14 @@ tzdata>=2025.2
 uritemplate>=4.2.0
 urllib3>=2.5.0
 uvicorn>=0.38.0
-virtualenv>=20.35.3
+virtualenv>=20.35.4
 wadler-lindig>=0.1.7
 websockets>=15.0.1
 werkzeug>=3.1.3
 wheel>=0.45.1
 wrapt>=2.0.0
-xprof>=2.20.7
+xprof>=2.21.0
 xxhash>=3.6.0
 yarl>=1.22.0
 zipp>=3.23.0
-zstandard>=0.25.0
\ No newline at end of file
+zstandard>=0.25.0
diff --git a/dependencies/requirements/generated_requirements/tpu-requirements.txt b/dependencies/requirements/generated_requirements/tpu-requirements.txt
index 8e6d948c8..0df8e7d9a 100644
--- a/dependencies/requirements/generated_requirements/tpu-requirements.txt
+++ b/dependencies/requirements/generated_requirements/tpu-requirements.txt
@@ -4,14 +4,14 @@
 absl-py>=2.3.1
 aiofiles>=25.1.0
 aiohappyeyeballs>=2.6.1
-aiohttp>=3.13.1
+aiohttp>=3.13.2
 aiosignal>=1.4.0
 annotated-doc>=0.0.3
 annotated-types>=0.7.0
 antlr4-python3-runtime>=4.9.3
 anyio>=4.11.0
 aqtp>=0.9.0
-array-record>=0.8.1
+array-record>=0.8.2
 astroid>=4.0.1
 astunparse>=1.6.3
 attrs>=25.4.0
@@ -23,7 +23,7 @@ cachetools>=6.2.1
 certifi>=2025.10.5
 cfgv>=3.4.0
 charset-normalizer>=3.4.4
-cheroot>=11.0.0
+cheroot>=11.1.1
 chex>=0.1.91
 click>=8.3.0
 cloud-accelerator-diagnostics>=0.1.1
@@ -34,7 +34,7 @@ colorama>=0.4.6
 contourpy>=1.3.3
 coverage>=7.11.0
 cycler>=0.12.1
-datasets>=4.3.0
+datasets>=4.4.0
 decorator>=5.2.1
 dill>=0.4.0
 distlib>=0.4.0
@@ -46,7 +46,7 @@ einshape>=1.0
 etils>=1.13.0
 evaluate>=0.4.6
 execnet>=2.1.1
-fastapi>=0.120.0
+fastapi>=0.121.0
 filelock>=3.20.0
 flatbuffers>=25.9.23
 flax>=0.12.0
@@ -55,29 +55,29 @@ frozenlist>=1.8.0
 fsspec>=2025.9.0
 gast>=0.6.0
 gcsfs>=2025.9.0
-google-api-core>=2.27.0
-google-api-python-client>=2.185.0
-google-auth-httplib2>=0.2.0
+google-api-core>=2.28.1
+google-api-python-client>=2.186.0
+google-auth-httplib2>=0.2.1
 google-auth-oauthlib>=1.2.2
-google-auth>=2.41.1
+google-auth>=2.42.1
 google-benchmark>=1.9.4
-google-cloud-aiplatform>=1.122.0
+google-cloud-aiplatform>=1.124.0
 google-cloud-appengine-logging>=1.7.0
 google-cloud-audit-log>=0.4.0
 google-cloud-bigquery>=3.38.0
-google-cloud-core>=2.4.3
+google-cloud-core>=2.5.0
 google-cloud-logging>=3.12.1
+google-cloud-mldiagnostics>=0.5.0
 google-cloud-monitoring>=2.28.0
 google-cloud-resource-manager>=1.15.0
-google-cloud-storage>=2.19.0
+google-cloud-storage>=3.4.1
 google-crc32c>=1.7.1
-google-genai>=1.46.0
-google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip
+google-genai>=1.48.0
 google-pasta>=0.2.0
 google-resumable-media>=2.7.2
 google-tunix>=0.1.3
 googleapis-common-protos>=1.71.0
-grain>=0.2.13
+grain>=0.2.14
 grpc-google-iam-v1>=0.14.3
 grpcio-status>=1.71.2
 grpcio>=1.75.1
@@ -85,7 +85,7 @@ gviz-api>=1.10.0
 h11>=0.16.0
 h5py>=3.15.1
 hf-transfer>=0.1.9
-hf-xet>=1.1.10 ; platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
+hf-xet>=1.2.0 ; platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
 httpcore>=1.0.9
 httplib2>=0.31.0
 httpx>=0.28.1
@@ -112,7 +112,7 @@ kagglehub>=0.3.13
 keras>=3.11.3
 kiwisolver>=1.4.9
 libclang>=18.1.1
-libcst>=1.8.5
+libcst>=1.8.6
 libtpu>=0.0.24 ; platform_machine == 'x86_64' and sys_platform == 'linux'
 llvmlite>=0.45.1
 lxml>=6.0.2
@@ -125,13 +125,12 @@ mdurl>=0.1.2
 ml-collections>=1.1.0
 ml-dtypes>=0.5.3
 ml-goodput-measurement>=0.0.15
-mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip
 more-itertools>=10.8.0
 mpmath>=1.3.0
 msgpack>=1.1.2
 msgspec>=0.19.0
 multidict>=6.7.0
-multiprocess>=0.70.16
+multiprocess>=0.70.18
 mypy-extensions>=1.1.0
 namex>=0.1.0
 nest-asyncio>=1.6.0
@@ -184,12 +183,12 @@ pyproject-hooks>=1.2.0
 pytest-xdist>=3.8.0
 pytest>=8.4.2
 python-dateutil>=2.9.0.post0
-python-dotenv>=1.1.1
+python-dotenv>=1.2.1
 pytype>=2024.10.11
 pytz>=2025.2
 pyyaml>=6.0.3
-qwix>=0.1.1
-regex>=2025.10.23
+qwix>=0.1.2
+regex>=2025.11.3
 requests-oauthlib>=2.0.0
 requests>=2.32.5
 rich>=14.2.0
@@ -207,7 +206,7 @@ simplejson>=3.20.2
 six>=1.17.0
 sniffio>=1.3.1
 sortedcontainers>=2.4.0
-starlette>=0.48.0
+starlette>=0.49.3
 sympy>=1.14.0
 tabulate>=0.9.0
 tenacity>=9.1.2
@@ -238,14 +237,14 @@ tzdata>=2025.2
 uritemplate>=4.2.0
 urllib3>=2.5.0
 uvicorn>=0.38.0
-virtualenv>=20.35.3
+virtualenv>=20.35.4
 wadler-lindig>=0.1.7
 websockets>=15.0.1
 werkzeug>=3.1.3
 wheel>=0.45.1
 wrapt>=2.0.0
-xprof>=2.20.7
+xprof>=2.21.0
 xxhash>=3.6.0
 yarl>=1.22.0
 zipp>=3.23.0
-zstandard>=0.25.0
\ No newline at end of file
+zstandard>=0.25.0
diff --git a/dependencies/requirements/requirements.txt b/dependencies/requirements/requirements.txt
index 36471cf55..8d89aee99 100644
--- a/dependencies/requirements/requirements.txt
+++ b/dependencies/requirements/requirements.txt
@@ -8,6 +8,7 @@ flax
 gcsfs
 google-api-python-client
 google-cloud-aiplatform
+google-cloud-mldiagnostics
 google-cloud-monitoring
 grain[parquet]
 huggingface_hub
diff --git a/dependencies/requirements/requirements_with_jax_ai_image.txt b/dependencies/requirements/requirements_with_jax_ai_image.txt
index 993f0e6d8..757917b24 100644
--- a/dependencies/requirements/requirements_with_jax_ai_image.txt
+++ b/dependencies/requirements/requirements_with_jax_ai_image.txt
@@ -3,6 +3,7 @@
 datasets @ https://github.com/huggingface/datasets/archive/6790e138c00b87a1ddc72184f89e7814cf784360.zip
 flax>=0.11.0
 google-api-python-client
+google-cloud-mldiagnostics
 google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip
 grain[parquet]>=0.2.13
 jaxtyping
diff --git a/src/MaxText/configs/base.yml b/src/MaxText/configs/base.yml
index 599c3dfd7..4e6afe230 100644
--- a/src/MaxText/configs/base.yml
+++ b/src/MaxText/configs/base.yml
@@ -573,7 +573,7 @@ colocated_python_data_input: False # experimental feature, under testing
 # Training loop
 steps: 150_001 # If set to -1 then will inherit value from learning_rate_schedule_steps
-log_period: 100 # Flushes Tensorboard
+log_period: 100 # How often (in steps) to flush Tensorboard, write GCS metrics, and update managed profiler metrics.
 jax_distributed_initialization_timeout: 300 # This is the default timeout in https://github.com/jax-ml/jax/blob/main/jax/_src/distributed.py
 # Note there are two separate initializations - the jax coordination service (aka jax.distributed.initialize) and the backend (e.g. PjRT), the timeout above refers
@@ -618,6 +618,12 @@ profile_cleanly: True # If set to true, adds a block_until_ready on train state
 profile_periodically_period: -1 # If set to a positive integer, profile every profile_periodically_period steps.
 # This is useful to debug scenarios where performance is changing.
+# Managed ML diagnostics settings. If the feature is enabled, it will:
+# - create a managed ML diagnostics run with all the MaxText configs
+# - upload xplane profiles, if profiling is enabled
+# - upload training metrics at the configured log_period interval
+managed_mldiagnostics: False # Whether to enable managed ML diagnostics
+managed_mldiagnostics_run_group: "" # Optional. Used to group multiple runs.
 # Dump HLO options
 dump_hlo: False
diff --git a/src/MaxText/managed_mldiagnostics.py b/src/MaxText/managed_mldiagnostics.py
new file mode 100644
index 000000000..ce252ad83
--- /dev/null
+++ b/src/MaxText/managed_mldiagnostics.py
@@ -0,0 +1,81 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Create the managed mldiagnostics run."""
+from typing import Any
+
+import simplejson as json
+
+from google_cloud_mldiagnostics import machinelearning_run
+
+from MaxText.pyconfig import KEYS_NO_LOGGING
+
+
+class ManagedMLDiagnostics:
+  """
+  ML Diagnostics Run, implemented with the Singleton pattern.
+  Ensures that only one instance can exist.
+  """
+
+  enabled = False
+
+  _instance = None  # Class attribute to hold the single instance
+
+  def __new__(cls, *args: Any, **kwargs: Any):
+    """
+    Overrides the instance creation method.
+    If an instance already exists, it is returned instead of creating a new one.
+    """
+    if cls._instance is None:
+      cls._instance = super(ManagedMLDiagnostics, cls).__new__(cls)
+
+    return cls._instance
+
+  def __init__(self, config):
+    """
+    Initializes the ManagedMLDiagnostics, ensuring this method runs only once.
+    """
+    # We need a flag to ensure __init__ only runs once,
+    # as the object is returned multiple times by __new__.
+    if hasattr(self, "_initialized"):
+      return
+    self._initialized = True
+    self.enabled = config.managed_mldiagnostics
+    if not self.enabled:
+      return
+
+    # Set up the managed mldiagnostics for profiling and metrics uploading.
+    def should_log_key(key, value):
+      if key in KEYS_NO_LOGGING:
+        return False
+      try:
+        # Verify the value can be serialized to json. If not, we'll skip it.
+        json.dumps(value, allow_nan=False)
+      except (TypeError, ValueError):
+        return False
+      return True
+
+    config_dict = {key: value for key, value in config.get_keys().items() if should_log_key(key, value)}
+
+    # Create a run for the managed mldiagnostics, and upload the configuration.
+    machinelearning_run(
+        # Use different names if on all devices.
+        name=f"{config.run_name}",
+        run_group=config.managed_mldiagnostics_run_group,
+        configs=config_dict,
+        gcs_path=config.managed_mldiagnostics_dir,
+        # TODO: b/455623960 - Remove the following once multi-region and prod support are enabled.
+        region="us-central1",
+        environment="autopush",  # Default would be "prod" for formal launch.
+    )
diff --git a/src/MaxText/metric_logger.py b/src/MaxText/metric_logger.py
index c3999a455..a1ec99d7c 100644
--- a/src/MaxText/metric_logger.py
+++ b/src/MaxText/metric_logger.py
@@ -25,15 +25,31 @@
 import jax
+
+from google_cloud_mldiagnostics import metrics as mlmetrics
+
 from MaxText import max_logging
 from MaxText import max_utils
 from MaxText import maxtext_utils
+from MaxText.managed_mldiagnostics import ManagedMLDiagnostics
 from MaxText.utils import gcs_utils
 from MaxText.gcp_workload_monitor import GCPWorkloadMonitor
 from MaxText.globals import EPS
 from collections import defaultdict
+
+# Mapping of MaxText metric names to managed profiler metric names
+_METRICS_TO_MANAGED = {
+    "learning/current_learning_rate": "learning_rate",
+    "learning/loss": "loss",
+    "learning/grad_norm": "gradient_norm",
+    "learning/total_weights": "total_weights",
+    "perf/step_time_seconds": "step_time",
+    "perf/per_device_tokens_per_sec": "throughput",
+    "perf/per_device_tflops_per_sec": "tflops",
+    # There are no mappings for the following metrics yet:
+    # "latency", "mfu"
+}
+

 def _prepare_metrics_for_json(metrics, step, run_name):
   """Converts metric dictionary into json supported types (e.g. float)"""
@@ -82,6 +98,7 @@ def __init__(self, config, learning_rate_schedule):
     self.learning_rate_schedule = learning_rate_schedule
     self.cumulative_eval_metrics = {"scalar": defaultdict(float)}
     self.buffered_train_metrics = None
+    self.managed_mldiagnostics = ManagedMLDiagnostics(config)

   def reset_eval_metrics(self):
     """Resets the cumulative metrics dictionary for a new evaluation run."""
@@ -101,6 +118,9 @@ def write_metrics(self, metrics, step, is_training=True):
     if self.config.gcs_metrics and jax.process_index() == 0:
       self.write_metrics_for_gcs(metrics, step, is_training)

+    if self.managed_mldiagnostics.enabled:
+      self.write_metrics_to_managed_mldiagnostics(metrics, step)
+
   def log_metrics(self, metrics, step, is_training):
     """Logs metrics via max_logging."""
     if is_training:
@@ -183,6 +203,18 @@ def write_metrics_to_tensorboard(self, metrics, step, is_training):
         max_logging.log(f"To see full metrics 'tensorboard --logdir={self.config.tensorboard_dir}'")
         self.writer.flush()

+  def write_metrics_to_managed_mldiagnostics(self, metrics, step):
+    """Write metrics to the managed profiler."""
+    if (step + 1) % self.config.log_period == 0 or step == self.config.steps - 1:
+      for metric_name in metrics.get("scalar", []):
+        value = metrics["scalar"][metric_name]
+        # For NumPy/JAX array objects (including single-element arrays), use .item()
+ if hasattr(value, "item"): + value = value.item() + mapped_metric_name = _METRICS_TO_MANAGED.get(metric_name, metric_name) + mlmetrics.record(mapped_metric_name, value, step=int(step)) + def write_setup_info_to_tensorboard(self, params): """Writes setup information like train config params, num model params, and XLA flags to TensorBoard.""" num_model_parameters = max_utils.calculate_num_params_from_pytree(params) diff --git a/src/MaxText/profiler.py b/src/MaxText/profiler.py index e32e49ff2..638ce8e92 100644 --- a/src/MaxText/profiler.py +++ b/src/MaxText/profiler.py @@ -21,7 +21,10 @@ import jax +from google_cloud_mldiagnostics import xprof + from MaxText import max_logging +from MaxText.managed_mldiagnostics import ManagedMLDiagnostics class Profiler: @@ -40,6 +43,8 @@ def __init__(self, config, offset_step=0): self.finished_initial_profile_step = self._set_last_profiler_step(config.profiler_steps, config.steps) if config.profiler != "" and self.start_initial_profile_step >= config.steps: raise ValueError("Profiling requested but initial profiling step set past training final step") + self.managed_mldiagnostics = ManagedMLDiagnostics(config) + self.prof = None # managed mldiagnostics xprof collector. def maybe_activate_profiler(self, step, state): """Conditionally activates the profiler based on the current step. @@ -56,6 +61,16 @@ def activate(self, blocking_object=None, optional_postfix=""): nsys profiler becomes no-op when libcudart.so is not available on the system.""" if self.profile_cleanly and blocking_object is not None: jax.block_until_ready(blocking_object) + + if self.managed_mldiagnostics.enabled and self.mode == "xplane": + # Handle the special profiling logic for managed_mldiagnostics + if self.prof is None: + # Starts xprof collector. + # Only profiling on the first device, if not upload_all_profiler_results. None is for all devices. + self.prof = xprof(process_index_list=None if self.upload_all_profiler_results else [0]) + self.prof.start() + return + if not (self.upload_all_profiler_results or jax.process_index() == 0): return if self.mode != "": @@ -84,6 +99,13 @@ def deactivate(self, blocking_object=None): The result is uploaded to the output bucket.""" if self.profile_cleanly and blocking_object is not None: jax.block_until_ready(blocking_object) + + if self.managed_mldiagnostics and self.mode == "xplane": + # Handle the special profileing logic for managed_mldiagnostics + if self.prof is not None: + self.prof.stop() + return + if not (self.upload_all_profiler_results or jax.process_index() == 0): return if self.mode == "nsys": diff --git a/src/MaxText/pyconfig.py b/src/MaxText/pyconfig.py index 177ea8bb2..cb07e4208 100644 --- a/src/MaxText/pyconfig.py +++ b/src/MaxText/pyconfig.py @@ -40,6 +40,9 @@ _MAX_PREFIX = "M_" +# Don't log the following keys. +KEYS_NO_LOGGING = ("hf_access_token",) + # YAML attribute to specify inheritance. 
 _BASE_CONFIG_ATTR = "base_config"
@@ -678,7 +681,7 @@ def __init__(self, argv: list[str], **kwargs):

     if raw_keys["log_config"]:
       for k in keys:
-        if k != "hf_access_token":
+        if k not in KEYS_NO_LOGGING:
           max_logging.log(f"Config param {k}: {raw_keys[k]}")

   @staticmethod
@@ -696,6 +699,8 @@ def user_init(raw_keys):
     raw_keys["tensorboard_dir"] = os.path.join(base_output_directory, run_name, "tensorboard", "")
     raw_keys["checkpoint_dir"] = os.path.join(base_output_directory, run_name, "checkpoints", "")
     raw_keys["metrics_dir"] = os.path.join(base_output_directory, run_name, "metrics", "")
+    # To work around SDK bug b/454725283, omit the trailing slash from managed_mldiagnostics_dir.
+    raw_keys["managed_mldiagnostics_dir"] = os.path.join(base_output_directory, run_name, "managed-mldiagnostics")

     if raw_keys["learning_rate_schedule_steps"] == -1:
       raw_keys["learning_rate_schedule_steps"] = raw_keys["steps"]
@@ -1219,7 +1224,7 @@ def validate_optimizer_sharding_over_data(raw_keys):
   zero1_supported_opt_types = ("adamw", "adam_pax")
   if raw_keys["opt_type"] not in zero1_supported_opt_types:
     raise ValueError(
-        f"Optimizer type {raw_keys["opt_type"]} is not supported for optimizer sharding.\n"
+        f"Optimizer type {raw_keys['opt_type']} is not supported for optimizer sharding.\n"
         f"Please use an optimizer from this list: {zero1_supported_opt_types}."
     )
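
Reviewer note: the sketches below are not part of the diff. They are standalone illustrations, with made-up names, of the patterns the new code relies on. This first one mirrors the two mechanisms behind ManagedMLDiagnostics and _METRICS_TO_MANAGED: a __new__-based singleton whose __init__ effectively runs once (so MetricLogger and Profiler can both construct it with a config and share one run), and a metric-name mapping that falls back to the original name when a metric is unmapped.

# Standalone sketch, not MaxText code; class and variable names are illustrative only.
class SingletonSketch:
  _instance = None  # single shared instance, as in ManagedMLDiagnostics

  def __new__(cls, *args, **kwargs):
    if cls._instance is None:
      cls._instance = super().__new__(cls)
    return cls._instance

  def __init__(self, enabled):
    # __init__ is called on every construction; the flag makes it effectively run once.
    if hasattr(self, "_initialized"):
      return
    self._initialized = True
    self.enabled = enabled


NAME_MAP = {"learning/loss": "loss"}  # hypothetical subset of a mapping like _METRICS_TO_MANAGED

first = SingletonSketch(enabled=True)
second = SingletonSketch(enabled=False)  # same object; the second __init__ is a no-op
assert first is second and second.enabled is True

for raw_name in ("learning/loss", "perf/unmapped_metric"):
  print(NAME_MAP.get(raw_name, raw_name))  # unmapped names pass through unchanged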
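A second sketch, under the same caveat, of the config-filtering idea in should_log_key: drop keys that must never be logged and any value that cannot be serialized to JSON, then upload only what remains. It uses the stdlib json module rather than simplejson.

# Standalone sketch, not MaxText code.
import json

SKIP_KEYS = ("hf_access_token",)  # plays the role of KEYS_NO_LOGGING


def loggable(key, value):
  if key in SKIP_KEYS:
    return False
  try:
    json.dumps(value, allow_nan=False)  # non-serializable values and NaN/inf are rejected
  except (TypeError, ValueError):
    return False
  return True


config = {"steps": 100, "hf_access_token": "secret", "mesh": object(), "bad_float": float("nan")}
print({k: v for k, v in config.items() if loggable(k, v)})  # -> {'steps': 100}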
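Finally, a sketch of the control flow added to Profiler.activate() and deactivate(): the managed collector handles only the enabled-plus-xplane combination and returns early, while every other combination stays on the pre-existing profiler path. The collector class here is a stand-in for google_cloud_mldiagnostics.xprof, not its real API.

# Standalone sketch, not MaxText code.
class FakeCollector:  # stand-in for the managed xprof collector
  def __init__(self):
    self.running = False

  def start(self):
    self.running = True

  def stop(self):
    self.running = False


class ProfilerSketch:
  def __init__(self, managed_enabled, mode):
    self.managed_enabled = managed_enabled
    self.mode = mode
    self.prof = None  # managed collector, created lazily on first activation

  def activate(self):
    if self.managed_enabled and self.mode == "xplane":
      if self.prof is None:
        self.prof = FakeCollector()
        self.prof.start()
      return "managed"
    return "default"  # the normal jax.profiler / nsys handling would run here

  def deactivate(self):
    if self.managed_enabled and self.mode == "xplane":
      if self.prof is not None:
        self.prof.stop()
      return "managed"
    return "default"


p = ProfilerSketch(managed_enabled=True, mode="xplane")
assert p.activate() == "managed" and p.prof.running
assert p.deactivate() == "managed" and not p.prof.running
assert ProfilerSketch(managed_enabled=False, mode="xplane").activate() == "default"

Gating on the enabled flag in both activate() and deactivate(), rather than on the singleton object itself (which is always truthy), is what keeps disabled runs on the default profiler path.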