
Adding basic elastic training (pause-and-resume) #1256

Draft · wants to merge 21 commits into base: main

Commits (21)
4f3c12f
Updating Pathways-TPU integration with AxLearn
lukebaumann Jun 16, 2025
f9fb40f
Adding the exclusive topology by hostname for the head job
lukebaumann Jul 14, 2025
489a775
Adding a model config for 7B model and v5litepod-32-2
lukebaumann Jul 10, 2025
42109b9
Turning off shuffling for datasets to decrease initialization time
lukebaumann Jul 14, 2025
12631f3
Removing memory and cpu limits on tpus because our cluster kueues do …
lukebaumann Jul 15, 2025
ef4b86a
Adding a patch to shard_map.py in JAX. This is to fix a bug in JAX 0.…
lukebaumann Jul 15, 2025
0b2e16b
Increase pathways head memory and CPU
shauryagup Jul 14, 2025
33c96a1
Add 150B config and shard_map patch
shauryagup Jul 16, 2025
a7e0f51
Move to a more recent version of pathways images since the default on…
shauryagup Jul 16, 2025
bd04c4d
Remove the batch_size=len(jax.devices()) workaround as it is not need…
shauryagup Jul 16, 2025
de38280
Add orbax checkpointing
shauryagup Jul 17, 2025
aabc337
Fixed the config for the 7B v5litepod-32-2 config
lukebaumann Jul 17, 2025
924db3c
Integrate all new Goodput features in v12
dipannita08 Jul 15, 2025
3fcb867
Revert "Integrate all new Goodput features in v12"
RoshaniN Jul 24, 2025
287ee81
Integrate AXLearn with Goodput v12
dipannita08 Jul 23, 2025
fcb3a7a
Increase mem and cpu further.
RoshaniN Jul 24, 2025
2c9c105
Adding RM logs by flag --alsologtostderr
lukebaumann Jul 24, 2025
ada6783
Removing resource limits for the controller container on the head pod
lukebaumann Jul 23, 2025
b52dfd8
Adding a pipe break missing flag that makes the cluster fail quickly
lukebaumann Jul 22, 2025
253d77f
Adding an extra slice to the Pathways cluster to swap in when there i…
lukebaumann Jul 23, 2025
9f4cb70
Adding basic elastic training
lukebaumann Jul 10, 2025
12 changes: 12 additions & 0 deletions Dockerfile
@@ -92,6 +92,18 @@ RUN uv pip install --prerelease=allow .[core,tpu] && uv cache clean
RUN if [ -n "$EXTRAS" ]; then uv pip install .[$EXTRAS] && uv cache clean; fi
COPY . .

################################################################################
# Pathways-TPU container spec. #
################################################################################

FROM base AS pathways-tpu

ARG EXTRAS=

RUN uv pip install --prerelease=allow .[core,pathways-tpu] && uv cache clean
RUN if [ -n "$EXTRAS" ]; then uv pip install .[$EXTRAS] && uv cache clean; fi
COPY . .

################################################################################
# GPU container spec. #
################################################################################
13 changes: 1 addition & 12 deletions axlearn/cloud/gcp/jobset_utils.py
@@ -452,14 +452,6 @@ def _build_container(self) -> Nested[Any]:
env_vars["ENABLE_ICI_RESILIENCY"] = str(cfg.enable_tpu_ici_resiliency).lower()

resources = {"limits": {"google.com/tpu": system.chips_per_vm}}
# Set request memory by host machine type.
machine_memory_gi = GCE_MACHINE_TYPE_TO_MEMORY_CHARACTERISTICS.get(
system.gce_machine_type, None
)
if machine_memory_gi is not None:
request_memory_gi = machine_memory_gi * _MEMORY_REQUEST_PERCENTAGE
resources["limits"]["memory"] = f"{machine_memory_gi}Gi"
resources["requests"] = {"memory": f"{math.floor(request_memory_gi)}Gi"}

k8s_env_vars = [dict(name=k, value=str(v)) for k, v in env_vars.items()]
k8s_env_vars.append(
@@ -508,10 +500,7 @@ def _build_uploader_container(
dst = f"{cfg.output_dir}/output/$HOSTNAME/"
interval_s = 60
sync_command = f"while true; do gsutil -m rsync -r {src} {dst}; sleep {interval_s}; done"
resources = {
"requests": {"cpu": "100m", "memory": "128Mi"},
"limits": {"cpu": "500m", "memory": "256Mi"},
}
resources = {}
return dict(
name="output-uploader",
image="google/cloud-sdk:alpine",
196 changes: 124 additions & 72 deletions axlearn/cloud/gcp/measurement.py
@@ -2,6 +2,9 @@

"""Measurement utils for GCP.
For detailed documentation and advanced usage, please refer to:
axlearn/docs/05-Goodput-Monitoring.md
Example:
# Enable Goodput when launching an AXLearn training job
@@ -13,10 +16,14 @@
--recorder_spec=name=my-run-with-goodput \
--recorder_spec=upload_dir=my-output-directory/summaries \
--recorder_spec=upload_interval=30 \
--recorder_spec=step_deviation_interval_seconds=30
--recorder_spec=rolling_window_size=86400,259200,432000
"""

import contextlib
import os
from typing import Optional, Sequence

import jax
from absl import flags, logging
from ml_goodput_measurement import goodput
@@ -38,13 +45,19 @@ class Config(measurement.Recorder.Config):
Attributes:
upload_dir: Directory to store metrics for the monitor.
upload_interval: Time interval (seconds) for monitoring uploads.
step_deviation_interval_seconds: Time interval (seconds) for step deviation metrics
uploads. -1 to disable step deviation uploads.
See "How to Monitor Cumulative Goodput Metrics" in
docs/05-Goodput-Monitoring.md for more details.
rolling_window_size: A sequence of integers defining the rolling window sizes in
seconds.
See "How to Monitor Rolling Window Goodput Metrics" in
docs/05-Goodput-Monitoring.md for more details.
jax_backend: Jax backend type to infer Pathways environment.
"""

upload_dir: Required[str] = REQUIRED
upload_interval: Required[int] = REQUIRED
step_deviation_interval_seconds: int = 30 # Default to 30 seconds
rolling_window_size: Sequence[int] = []
jax_backend: Optional[str] = None

@classmethod
def from_flags(cls, fv: flags.FlagValues) -> "GoodputRecorder":
@@ -53,68 +66,78 @@ def from_flags(cls, fv: flags.FlagValues) -> "GoodputRecorder":
`fv.recorder_spec` will be interpreted as a list of `key=value` pairs; config names
corresponding to keys will be set to the corresponding values. A GoodputRecorder can
additionally take in following Tensorboard configs in the recorder_spec:
- upload_dir: The directory to write Tensorboard data to.
- upload_interval: The time interval in seconds at which to query and upload data
to Tensorboard.
- step_deviation_interval_seconds: Time interval (seconds) for step deviation metrics
uploads. Set to less than or equal to 0 to disable step deviation uploads.
- upload_dir: The directory to write Tensorboard data to.
- upload_interval: The time interval in seconds at which to query and upload data
to Tensorboard.
- rolling_window_size: Comma-separated list of integers representing rolling window
sizes in seconds.
- jax_backend: The type of jax backend.
"""
cfg: measurement.Recorder.Config = cls.default_config()
cfg = maybe_set_config(cfg, **parse_kv_flags(fv.recorder_spec, delimiter="="))
return cfg.instantiate()
parsed_flags = parse_kv_flags(fv.recorder_spec, delimiter="=")
if "upload_interval" in parsed_flags:
parsed_flags["upload_interval"] = int(parsed_flags["upload_interval"])
if "rolling_window_size" in parsed_flags and isinstance(
parsed_flags["rolling_window_size"], str
):
parsed_flags["rolling_window_size"] = [
int(x) for x in parsed_flags["rolling_window_size"].split(",")
]
return maybe_set_config(cfg, **parsed_flags).instantiate()

def __init__(self, cfg):
super().__init__(cfg)
cfg: GoodputRecorder.Config = self.config
self._recorder = None
self._monitor = None

def record(self, event: measurement.Event, *args, **kwargs):
# Lazily instantiate the recorder. This avoids invoking jax before setup is complete.
self._recorder: Optional[goodput.GoodputRecorder] = None
self._monitor: Optional[goodput_monitoring.GoodputMonitor] = None
self._rolling_window_monitor: Optional[goodput_monitoring.GoodputMonitor] = None
self._job_name = cfg.name
self._logger_name = f"goodput_logger_{cfg.name}"

@contextlib.contextmanager
def record_event(self, event: measurement.Event, *args, **kwargs):
"""Records a goodput event using a context manager."""
# Lazily instantiate the recorder if it hasn't been already.
if self._recorder is None:
cfg: GoodputRecorder.Config = self.config
if jax.process_index() == 0:
logging.info("Lazily instantiating goodput recorder.")
self._recorder = goodput.GoodputRecorder(
job_name=cfg.name,
logger_name=f"goodput_logger_{cfg.name}",
job_name=self._job_name,
logger_name=self._logger_name,
logging_enabled=(jax.process_index() == 0),
)

if event == measurement.Event.START_JOB:
self._recorder.record_job_start_time(*args, **kwargs)
elif event == measurement.Event.END_JOB:
self._recorder.record_job_end_time(*args, **kwargs)
elif event == measurement.Event.START_STEP:
self._recorder.record_step_start_time(*args, **kwargs)
elif event == measurement.Event.START_ACCELERATOR_INIT:
self._recorder.record_tpu_init_start_time(*args, **kwargs)
elif event == measurement.Event.END_ACCELERATOR_INIT:
self._recorder.record_tpu_init_end_time(*args, **kwargs)
elif event == measurement.Event.START_TRAINING_PREPARATION:
self._recorder.record_training_preparation_start_time(*args, **kwargs)
elif event == measurement.Event.END_TRAINING_PREPARATION:
self._recorder.record_training_preparation_end_time(*args, **kwargs)
elif event == measurement.Event.START_DATA_LOADING:
self._recorder.record_data_loading_start_time(*args, **kwargs)
elif event == measurement.Event.END_DATA_LOADING:
self._recorder.record_data_loading_end_time(*args, **kwargs)
elif event == measurement.Event.START_CUSTOM_BADPUT_EVENT:
self._recorder.record_custom_badput_event_start_time(*args, **kwargs)
elif event == measurement.Event.END_CUSTOM_BADPUT_EVENT:
self._recorder.record_custom_badput_event_end_time(*args, **kwargs)
else:
logging.log_first_n(
logging.WARNING,
"Ignoring unknown event %s",
1,
event,
start_method_name = f"record_{event.value}_start_time"
end_method_name = f"record_{event.value}_end_time"

record_event_start = getattr(self._recorder, start_method_name, None)
record_event_end = getattr(self._recorder, end_method_name, None)

try:
if record_event_start:
record_event_start(*args, **kwargs)
except RuntimeError as e:
logging.warning(
"Failed to record start of event %s. Error: %s", event.value, e, exc_info=True
)

def start_monitoring(self, *args, **kwargs):
"""Starts Monitoring of Goodput.
try:
yield
finally:
try:
if record_event_end:
record_event_end(*args, **kwargs)
except RuntimeError as e:
logging.warning(
"Failed to record end of event %s. Error: %s", event.value, e, exc_info=True
)

@contextlib.contextmanager
def maybe_monitor_goodput(self, *args, **kwargs):
"""Monitor cumulative goodput if enabled.
Instantiate ml-goodput-measurement's GoodputMonitor to asynchronously calculate
Goodput and Badput at the upload_interval and upload to the specified TensorBoard
directory.
Goodput, Badput, Step & Disruption Information at the upload_interval to the
specified TensorBoard directory and Google Cloud Monitoring.
Note: This function requires initialization of distributed JAX before it is called.
If there are internal GCP errors from querying and uploading data, these will be
logged without affecting the workload. GoodputMonitor logs will provide further
@@ -123,33 +146,62 @@ def start_monitoring(self, *args, **kwargs):
Default behavior is to push metrics to Google Cloud Monitoring.
This behavior can be overridden by configuring `goodput_monitoring.GCPOptions`
"""
cfg: GoodputRecorder.Config = self.config
include_step_deviation = True
if jax.process_index() == 0:
if jax.process_index() != 0:
yield
return
try:
if self._monitor is None:
if int(cfg.step_deviation_interval_seconds) <= 0:
include_step_deviation = False

gcp_options = goodput_monitoring.GCPOptions(
enable_gcp_goodput_metrics=True,
enable_gcp_step_deviation_metrics=include_step_deviation,
)
self._monitor = goodput_monitoring.GoodputMonitor(
job_name=cfg.name,
logger_name=f"goodput_logger_{cfg.name}",
tensorboard_dir=cfg.upload_dir,
upload_interval=int(cfg.upload_interval),
job_name=self._job_name,
logger_name=self._logger_name,
tensorboard_dir=self.config.upload_dir,
upload_interval=self.config.upload_interval,
monitoring_enabled=True,
pathway_enabled=self.config.jax_backend == "proxy",
include_badput_breakdown=True,
include_step_deviation=include_step_deviation,
step_deviation_interval_seconds=int(cfg.step_deviation_interval_seconds),
gcp_options=gcp_options,
)

self._monitor.start_goodput_uploader(*args, **kwargs)
logging.info("Started Goodput upload to Tensorboard & GCM in the background!")
if include_step_deviation:
self._monitor.start_step_deviation_uploader(*args, **kwargs)
yield
finally:
if self._monitor:
self._monitor.stop_goodput_uploader()
logging.info("Flushed final metrics and safe exited from Goodput monitoring.")

@contextlib.contextmanager
def maybe_monitor_rolling_window_goodput(self):
"""Monitor rolling window goodput if enabled."""
if not self.config.rolling_window_size or jax.process_index() != 0:
yield
return
try:
if self._rolling_window_monitor is None:
rolling_window_tensorboard_dir = os.path.join(
self.config.upload_dir, f"rolling_window_{self.config.name}"
)
self._rolling_window_monitor = goodput_monitoring.GoodputMonitor(
job_name=self._job_name,
logger_name=self._logger_name,
tensorboard_dir=rolling_window_tensorboard_dir,
upload_interval=self.config.upload_interval,
monitoring_enabled=True,
pathway_enabled=self.config.jax_backend == "proxy",
include_badput_breakdown=True,
)
self._rolling_window_monitor.start_rolling_window_goodput_uploader(
self.config.rolling_window_size
)
logging.info("Started Rolling Window Goodput monitoring in the background!")
yield
finally:
if self._rolling_window_monitor:
self._rolling_window_monitor.stop_rolling_window_goodput_uploader()
logging.info(
"Started Step Deviation upload to Tensorboard & GCM in the background!"
"Flushed final metrics and safe exited from Rolling Window Goodput monitoring."
)


def create_goodput_recorder(cfg: GoodputRecorder.Config):
"""Factory method to create GoodputRecorder."""
return GoodputRecorder(cfg)
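
Since this diff replaces the per-event `record()` dispatch and `start_monitoring()` with context managers, a minimal usage sketch may help reviewers. It assumes axlearn's standard `default_config().set(...).instantiate()` config pattern; the `measurement` import path and the `Event` member used are illustrative guesses, and the flag values are taken from the module docstring rather than a real run.

```python
# Reviewer aid (not part of this PR): a minimal sketch of driving the refactored
# GoodputRecorder, based only on the methods visible in this diff.
from axlearn.cloud.gcp.measurement import GoodputRecorder
from axlearn.common import measurement  # assumed home of measurement.Event

recorder = GoodputRecorder.default_config().set(
    name="my-run-with-goodput",
    upload_dir="my-output-directory/summaries",
    upload_interval=30,
    rolling_window_size=[86400, 259200, 432000],
    jax_backend="proxy",  # "proxy" marks a Pathways backend for the monitors
).instantiate()

# Both monitors are no-ops on processes other than index 0; the rolling-window
# monitor is also a no-op when rolling_window_size is empty.
with recorder.maybe_monitor_goodput(), recorder.maybe_monitor_rolling_window_goodput():
    # record_event() wraps a block with paired start/end calls resolved
    # dynamically as record_<event.value>_start_time / record_<event.value>_end_time.
    with recorder.record_event(measurement.Event.TRAINING_PREPARATION):  # illustrative member
        pass  # training setup / loop would go here
```

The nesting mirrors the lifecycle the diff implies: monitors start once per job on process 0 and are flushed in the `finally` blocks on exit, while `record_event` brackets individual phases with paired start/end recordings.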