diff --git a/Dockerfile b/Dockerfile index 29db664d3..61a5bf866 100644 --- a/Dockerfile +++ b/Dockerfile @@ -92,6 +92,20 @@ RUN uv pip install --prerelease=allow .[core,tpu] && uv cache clean RUN if [ -n "$EXTRAS" ]; then uv pip install .[$EXTRAS] && uv cache clean; fi COPY . . +################################################################################ +# Pathways-TPU container spec. # +################################################################################ + +FROM base AS pathways-tpu + +ARG EXTRAS= + +RUN uv pip install debugpy && uv cache clean +EXPOSE 5678 +RUN uv pip install --prerelease=allow .[core,pathways-tpu] && uv cache clean +RUN if [ -n "$EXTRAS" ]; then uv pip install .[$EXTRAS] && uv cache clean; fi +COPY . . + ################################################################################ # GPU container spec. # ################################################################################ diff --git a/axlearn/cloud/gcp/jobset_utils.py b/axlearn/cloud/gcp/jobset_utils.py index ab3a7daaf..c2a988692 100644 --- a/axlearn/cloud/gcp/jobset_utils.py +++ b/axlearn/cloud/gcp/jobset_utils.py @@ -452,14 +452,6 @@ def _build_container(self) -> Nested[Any]: env_vars["ENABLE_ICI_RESILIENCY"] = str(cfg.enable_tpu_ici_resiliency).lower() resources = {"limits": {"google.com/tpu": system.chips_per_vm}} - # Set request memory by host machine type. - machine_memory_gi = GCE_MACHINE_TYPE_TO_MEMORY_CHARACTERISTICS.get( - system.gce_machine_type, None - ) - if machine_memory_gi is not None: - request_memory_gi = machine_memory_gi * _MEMORY_REQUEST_PERCENTAGE - resources["limits"]["memory"] = f"{machine_memory_gi}Gi" - resources["requests"] = {"memory": f"{math.floor(request_memory_gi)}Gi"} k8s_env_vars = [dict(name=k, value=str(v)) for k, v in env_vars.items()] k8s_env_vars.append( @@ -479,6 +471,7 @@ def _build_container(self) -> Nested[Any]: dict(containerPort=8080), # Port for MXLA coordinator. dict(containerPort=8431), # Port to export TPU runtime metrics. dict(containerPort=self._load_balancer.target_port), # Port for load balancer. + dict(containerPort=5678), # Port for debugger. ], securityContext=dict(privileged=True), # TODO(markblee): Improve SIGTERM behavior for command. @@ -508,10 +501,7 @@ def _build_uploader_container( dst = f"{cfg.output_dir}/output/$HOSTNAME/" interval_s = 60 sync_command = f"while true; do gsutil -m rsync -r {src} {dst}; sleep {interval_s}; done" - resources = { - "requests": {"cpu": "100m", "memory": "128Mi"}, - "limits": {"cpu": "500m", "memory": "256Mi"}, - } + resources = {} return dict( name="output-uploader", image="google/cloud-sdk:alpine", diff --git a/axlearn/cloud/gcp/measurement.py b/axlearn/cloud/gcp/measurement.py index 0d4ce0069..849f8983d 100644 --- a/axlearn/cloud/gcp/measurement.py +++ b/axlearn/cloud/gcp/measurement.py @@ -2,6 +2,9 @@ """Measurement utils for GCP. 
+ For detailed documentation and advanced usage, please refer to: + axlearn/docs/05-Goodput-Monitoring.md + Example: # Enable Goodput when launching an AXLearn training job @@ -13,10 +16,14 @@ --recorder_spec=name=my-run-with-goodput \ --recorder_spec=upload_dir=my-output-directory/summaries \ --recorder_spec=upload_interval=30 \ - --recorder_spec=step_deviation_interval_seconds=30 + --recorder_spec=rolling_window_size=86400,604800 """ +import contextlib +import os +from typing import Optional, Sequence + import jax from absl import flags, logging from ml_goodput_measurement import goodput @@ -38,13 +45,19 @@ class Config(measurement.Recorder.Config): Attributes: upload_dir: Directory to store metrics for the monitor. upload_interval: Time interval (seconds) for monitoring uploads. - step_deviation_interval_seconds: Time interval (seconds) for step deviation metrics - uploads. -1 to disable step deviation uploads. + See "How to Monitor Cumulative Goodput Metrics" in + docs/05-Goodput-Monitoring.md for more details. + rolling_window_size: A sequence of integers defining the rolling window sizes in + seconds. + See "How to Monitor Rolling Window Goodput Metrics" in + docs/05-Goodput-Monitoring.md for more details. + jax_backend: Jax backend type to infer Pathways environment. """ upload_dir: Required[str] = REQUIRED upload_interval: Required[int] = REQUIRED - step_deviation_interval_seconds: int = 30 # Default to 30 seconds + rolling_window_size: Sequence[int] = [] + jax_backend: Optional[str] = None @classmethod def from_flags(cls, fv: flags.FlagValues) -> "GoodputRecorder": @@ -53,68 +66,81 @@ def from_flags(cls, fv: flags.FlagValues) -> "GoodputRecorder": `fv.recorder_spec` will be interpreted as a list of `key=value` pairs; config names corresponding to keys will be set to the corresponding values. A GoodputRecorder can additionally take in following Tensorboard configs in the recorder_spec: - - upload_dir: The directory to write Tensorboard data to. - - upload_interval: The time interval in seconds at which to query and upload data - to Tensorboard. - - step_deviation_interval_seconds: Time interval (seconds) for step deviation metrics - uploads. Set to less than or equal to 0 to disable step deviation uploads. + - upload_dir: The directory to write Tensorboard data to. + - upload_interval: The time interval in seconds at which to query and upload data + to Tensorboard. + - rolling_window_size: Comma-separated list of integers representing rolling window + sizes in seconds. + - jax_backend: The type of jax backend. """ cfg: measurement.Recorder.Config = cls.default_config() - cfg = maybe_set_config(cfg, **parse_kv_flags(fv.recorder_spec, delimiter="=")) - return cfg.instantiate() + parsed_flags = parse_kv_flags(fv.recorder_spec, delimiter="=") + if "upload_interval" in parsed_flags: + parsed_flags["upload_interval"] = int(parsed_flags["upload_interval"]) + if "rolling_window_size" in parsed_flags and isinstance( + parsed_flags["rolling_window_size"], str + ): + parsed_flags["rolling_window_size"] = [ + int(x) for x in parsed_flags["rolling_window_size"].split(",") + ] + return maybe_set_config(cfg, **parsed_flags).instantiate() def __init__(self, cfg): super().__init__(cfg) - cfg: GoodputRecorder.Config = self.config - self._recorder = None - self._monitor = None - - def record(self, event: measurement.Event, *args, **kwargs): - # Lazily instantiate the recorder. This avoids invoking jax before setup is complete. 
+ self._recorder: Optional[goodput.GoodputRecorder] = None + self._monitor: Optional[goodput_monitoring.GoodputMonitor] = None + self._rolling_window_monitor: Optional[goodput_monitoring.GoodputMonitor] = None + self._job_name = cfg.name + self._logger_name = f"goodput_logger_{cfg.name}" + + @contextlib.contextmanager + def record_event(self, event: measurement.EventType, *args, **kwargs): + """Records a goodput event using a context manager.""" + # Lazily instantiate the recorder if it hasn't been already. if self._recorder is None: - cfg: GoodputRecorder.Config = self.config + if jax.process_index() == 0: + logging.info("Lazily instantiating goodput recorder.") self._recorder = goodput.GoodputRecorder( - job_name=cfg.name, - logger_name=f"goodput_logger_{cfg.name}", + job_name=self._job_name, + logger_name=self._logger_name, logging_enabled=(jax.process_index() == 0), ) - if event == measurement.Event.START_JOB: - self._recorder.record_job_start_time(*args, **kwargs) - elif event == measurement.Event.END_JOB: - self._recorder.record_job_end_time(*args, **kwargs) - elif event == measurement.Event.START_STEP: - self._recorder.record_step_start_time(*args, **kwargs) - elif event == measurement.Event.START_ACCELERATOR_INIT: - self._recorder.record_tpu_init_start_time(*args, **kwargs) - elif event == measurement.Event.END_ACCELERATOR_INIT: - self._recorder.record_tpu_init_end_time(*args, **kwargs) - elif event == measurement.Event.START_TRAINING_PREPARATION: - self._recorder.record_training_preparation_start_time(*args, **kwargs) - elif event == measurement.Event.END_TRAINING_PREPARATION: - self._recorder.record_training_preparation_end_time(*args, **kwargs) - elif event == measurement.Event.START_DATA_LOADING: - self._recorder.record_data_loading_start_time(*args, **kwargs) - elif event == measurement.Event.END_DATA_LOADING: - self._recorder.record_data_loading_end_time(*args, **kwargs) - elif event == measurement.Event.START_CUSTOM_BADPUT_EVENT: - self._recorder.record_custom_badput_event_start_time(*args, **kwargs) - elif event == measurement.Event.END_CUSTOM_BADPUT_EVENT: - self._recorder.record_custom_badput_event_end_time(*args, **kwargs) - else: - logging.log_first_n( - logging.WARNING, - "Ignoring unknown event %s", - 1, - event, - ) + start_method_name = f"record_{event.value}_start_time" + end_method_name = f"record_{event.value}_end_time" - def start_monitoring(self, *args, **kwargs): - """Starts Monitoring of Goodput. + record_event_start = getattr(self._recorder, start_method_name, None) + record_event_end = getattr(self._recorder, end_method_name, None) + + if record_event_start: + try: + record_event_start(*args, **kwargs) + except RuntimeError as e: + logging.warning( + "Failed to record start of event %s. Error: %s", event.value, e, exc_info=True + ) + # pylint: disable=try-except-raise + try: + yield # Run the user code in the context + except Exception: + raise + else: + if record_event_end: + try: + record_event_end(*args, **kwargs) + except RuntimeError as e: + logging.warning( + "Failed to record end of event %s. Error: %s", event.value, e, exc_info=True + ) + # pylint: enable=try-except-raise + + @contextlib.contextmanager + def _maybe_monitor_goodput(self, *args, **kwargs): + """Monitor cumulative goodput if enabled. Instantiate ml-goodput-measurement's GoodputMonitor to asynchronously calculate - Goodput and Badput at the upload_interval and upload to the specified TensorBoard - directory. 
+ Goodput, Badput, Step & Disruption Information at the upload_interval to the + specified TensorBoard directory and Google Cloud Monitoring. Note: This function requires initialization of distributed JAX before it is called. If there are internal GCP errors from querying and uploading data, these will be logged without affecting the workload. GoodputMonitor logs will provide further @@ -123,33 +149,68 @@ def start_monitoring(self, *args, **kwargs): Default behavior is to push metrics to Google Cloud Monitoring. This behavior can be overridden by configuring `goodput_monitoring.GCPOptions` """ - cfg: GoodputRecorder.Config = self.config - include_step_deviation = True - if jax.process_index() == 0: + if jax.process_index() != 0: + yield + return + try: if self._monitor is None: - if int(cfg.step_deviation_interval_seconds) <= 0: - include_step_deviation = False - - gcp_options = goodput_monitoring.GCPOptions( - enable_gcp_goodput_metrics=True, - enable_gcp_step_deviation_metrics=include_step_deviation, - ) self._monitor = goodput_monitoring.GoodputMonitor( - job_name=cfg.name, - logger_name=f"goodput_logger_{cfg.name}", - tensorboard_dir=cfg.upload_dir, - upload_interval=int(cfg.upload_interval), + job_name=self._job_name, + logger_name=self._logger_name, + tensorboard_dir=self.config.upload_dir, + upload_interval=self.config.upload_interval, monitoring_enabled=True, + pathway_enabled=self.config.jax_backend == "proxy", include_badput_breakdown=True, - include_step_deviation=include_step_deviation, - step_deviation_interval_seconds=int(cfg.step_deviation_interval_seconds), - gcp_options=gcp_options, ) self._monitor.start_goodput_uploader(*args, **kwargs) logging.info("Started Goodput upload to Tensorboard & GCM in the background!") - if include_step_deviation: - self._monitor.start_step_deviation_uploader(*args, **kwargs) + yield + finally: + if self._monitor: + self._monitor.stop_goodput_uploader() + logging.info("Flushed final metrics and safe exited from Goodput monitoring.") + + @contextlib.contextmanager + def _maybe_monitor_rolling_window_goodput(self): + """Monitor rolling window goodput if enabled.""" + if not self.config.rolling_window_size or jax.process_index() != 0: + yield + return + try: + if self._rolling_window_monitor is None: + rolling_window_tensorboard_dir = os.path.join( + self.config.upload_dir, f"rolling_window_{self.config.name}" + ) + self._rolling_window_monitor = goodput_monitoring.GoodputMonitor( + job_name=self._job_name, + logger_name=self._logger_name, + tensorboard_dir=rolling_window_tensorboard_dir, + upload_interval=self.config.upload_interval, + monitoring_enabled=True, + pathway_enabled=self.config.jax_backend == "proxy", + include_badput_breakdown=True, + ) + self._rolling_window_monitor.start_rolling_window_goodput_uploader( + self.config.rolling_window_size + ) + logging.info("Started Rolling Window Goodput monitoring in the background!") + yield + finally: + if self._rolling_window_monitor: + self._rolling_window_monitor.stop_rolling_window_goodput_uploader() logging.info( - "Started Step Deviation upload to Tensorboard & GCM in the background!" + "Flushed final metrics and safe exited from Rolling Window Goodput monitoring." 
) + + def maybe_monitor_all_goodput(self): + goodput_monitor_manager = self._maybe_monitor_goodput() + rolling_goodput_monitor_manager = self._maybe_monitor_rolling_window_goodput() + + @contextlib.contextmanager + def monitor_goodput(): + with goodput_monitor_manager, rolling_goodput_monitor_manager: + yield + + return monitor_goodput() diff --git a/axlearn/cloud/gcp/measurement_test.py b/axlearn/cloud/gcp/measurement_test.py index e14fc16c4..f5ec1ddc4 100644 --- a/axlearn/cloud/gcp/measurement_test.py +++ b/axlearn/cloud/gcp/measurement_test.py @@ -3,191 +3,373 @@ """Tests measurement utils for GCP.""" # pylint: disable=protected-access -import contextlib from unittest import mock -from absl import flags +from absl import flags, logging from absl.testing import parameterized from axlearn.cloud.gcp.measurement import GoodputRecorder from axlearn.common import measurement +from axlearn.common.config import RequiredFieldMissingError class GoodputRecorderTest(parameterized.TestCase): """Tests GoodputRecorder.""" @parameterized.parameters( - (None,), (["name=test-name", "upload_dir=/test/path/to/upload", "upload_interval=15"],) - ) - def test_from_flags(self, spec): - fv = flags.FlagValues() - measurement.define_flags(flag_values=fv) - if spec is not None: - fv.set_default("recorder_spec", spec) - fv.mark_as_parsed() - - if spec is None: - ctx = self.assertRaisesRegex(ValueError, "name") - else: - ctx = contextlib.nullcontext() - - with ctx: - recorder = GoodputRecorder.from_flags(fv) - # Recorder is not instantiated until first event. - self.assertIsNone(recorder._recorder) - - def test_record_and_monitor(self): - fv = flags.FlagValues() - measurement.define_flags(flag_values=fv) - fv.set_default( - "recorder_spec", - ["name=test-name", "upload_dir=/test/path/to/upload", "upload_interval=15"], - ) - fv.mark_as_parsed() - - recorder = GoodputRecorder.from_flags(fv) - recorder._recorder = mock.MagicMock() - recorder.record(measurement.Event.START_JOB) - self.assertTrue(recorder._recorder.record_job_start_time.called) - - def test_start_goodput_monitoring(self): - fv = flags.FlagValues() - measurement.define_flags(flag_values=fv) - fv.set_default( - "recorder_spec", - [ + dict( + recorder_spec=[ "name=test-name", - "upload_dir=/test/path/to/upload", + "upload_dir=/test/path", "upload_interval=15", - "step_deviation_interval_seconds=-1", ], - ) - fv.mark_as_parsed() - - recorder = GoodputRecorder.from_flags(fv) - self.assertIsNone(recorder._monitor) # Ensure _monitor is initially None - - with mock.patch("ml_goodput_measurement.monitoring.GoodputMonitor") as mock_goodput_monitor: - with mock.patch("ml_goodput_measurement.monitoring.GCPOptions") as mock_gcp_options: - mock_monitor_instance = mock_goodput_monitor.return_value - recorder.start_monitoring() - mock_gcp_options.assert_called_once_with( - enable_gcp_goodput_metrics=True, - enable_gcp_step_deviation_metrics=False, - ) - mock_gcp_options_instance = mock_gcp_options.return_value - - # Check that GoodputMonitor was instantiated - mock_goodput_monitor.assert_called_once_with( - job_name="test-name", - logger_name="goodput_logger_test-name", - tensorboard_dir="/test/path/to/upload", - upload_interval=15, - monitoring_enabled=True, - include_badput_breakdown=True, - include_step_deviation=False, - step_deviation_interval_seconds=-1, - gcp_options=mock_gcp_options_instance, - ) - - # Ensure that start_goodput_uploader is called on the monitor instance - mock_monitor_instance.start_goodput_uploader.assert_called_once() - 
self.assertIsNotNone(recorder._monitor) - - def test_start_goodput_and_step_deviation_monitoring(self): - fv = flags.FlagValues() - measurement.define_flags(flag_values=fv) - fv.set_default( - "recorder_spec", - [ + expected_rolling_window_size=[], + expected_jax_backend=None, + ), + dict( + recorder_spec=[ "name=test-name", - "upload_dir=/test/path/to/upload", + "upload_dir=/test/path", "upload_interval=15", - "step_deviation_interval_seconds=30", + "rolling_window_size=1,2,3", + "jax_backend=proxy", ], + expected_rolling_window_size=[1, 2, 3], + expected_jax_backend="proxy", + ), + ) + def test_from_flags( + self, + recorder_spec, + expected_rolling_window_size, + expected_jax_backend, + ): + """Tests that flags are correctly parsed into the config.""" + mock_fv = mock.MagicMock(spec=flags.FlagValues) + mock_fv.recorder_spec = recorder_spec + mock_fv.jax_backend = "tpu" + + recorder = GoodputRecorder.from_flags(mock_fv) + + self.assertEqual("test-name", recorder.config.name) + self.assertEqual("/test/path", recorder.config.upload_dir) + self.assertEqual(15, recorder.config.upload_interval) + self.assertEqual(expected_rolling_window_size, recorder.config.rolling_window_size) + self.assertEqual(expected_jax_backend, recorder.config.jax_backend) + + def test_from_flags_missing_required(self): + """Tests that missing required flags raise an error.""" + mock_fv = mock.MagicMock(spec=flags.FlagValues) + mock_fv.recorder_spec = ["name=test-name"] # Missing upload_dir/interval + mock_fv.jax_backend = "tpu" + with self.assertRaisesRegex(RequiredFieldMissingError, "upload_dir"): + GoodputRecorder.from_flags(mock_fv) + + @parameterized.parameters( + dict( + event=measurement.EventType.JOB, + expected_start="record_job_start_time", + expected_end="record_job_end_time", + args=(), + kwargs={}, + expect_end_call=True, + ), + dict( + event=measurement.EventType.STEP, + expected_start="record_step_start_time", + expected_end=None, + args=(123,), + kwargs={}, + expect_end_call=False, + ), + dict( + event=measurement.EventType.ACCELERATOR_INIT, + expected_start="record_tpu_init_start_time", + expected_end="record_tpu_init_end_time", + args=(), + kwargs={}, + expect_end_call=True, + ), + dict( + event=measurement.EventType.TRAINING_PREPARATION, + expected_start="record_training_preparation_start_time", + expected_end="record_training_preparation_end_time", + args=(), + kwargs={}, + expect_end_call=True, + ), + dict( + event=measurement.EventType.DATA_LOADING, + expected_start="record_data_loading_start_time", + expected_end="record_data_loading_end_time", + args=(), + kwargs={}, + expect_end_call=True, + ), + dict( + event=measurement.EventType.CUSTOM_BADPUT_EVENT, + expected_start="record_custom_badput_event_start_time", + expected_end="record_custom_badput_event_end_time", + args=(), + kwargs={"custom_badput_event_type": "TEST_TYPE"}, + expect_end_call=True, + ), + ) + @mock.patch("jax.process_index", return_value=0) + def test_record_event_context_manager_success( + self, _, event, expected_start, expected_end, args, kwargs, expect_end_call + ): + """Tests that record_event calls correct start and end methods with args and kwargs.""" + cfg = GoodputRecorder.default_config().set( + name="test", + upload_dir="/tmp/test", + upload_interval=1, ) - fv.mark_as_parsed() - - recorder = GoodputRecorder.from_flags(fv) - self.assertIsNone(recorder._monitor) # Ensure _monitor is initially None - - with mock.patch("ml_goodput_measurement.monitoring.GoodputMonitor") as mock_goodput_monitor: - with 
mock.patch("ml_goodput_measurement.monitoring.GCPOptions") as mock_gcp_options: - mock_monitor_instance = mock_goodput_monitor.return_value - recorder.start_monitoring() - mock_gcp_options.assert_called_once_with( - enable_gcp_goodput_metrics=True, - enable_gcp_step_deviation_metrics=True, - ) - mock_gcp_options_instance = mock_gcp_options.return_value - - # Check that GoodputMonitor was instantiated - mock_goodput_monitor.assert_called_once_with( - job_name="test-name", - logger_name="goodput_logger_test-name", - tensorboard_dir="/test/path/to/upload", - upload_interval=15, - monitoring_enabled=True, - include_badput_breakdown=True, - include_step_deviation=True, - step_deviation_interval_seconds=30, - gcp_options=mock_gcp_options_instance, - ) + recorder = GoodputRecorder(cfg) - # Ensure that start_goodput_uploader and start_step_deviation_uploader is called on - # the monitor instance - mock_monitor_instance.start_goodput_uploader.assert_called_once() - mock_monitor_instance.start_step_deviation_uploader.assert_called_once() - self.assertIsNotNone(recorder._monitor) - - def test_missing_required_flags(self): - fv = flags.FlagValues() - measurement.define_flags(flag_values=fv) - # Missing 'upload_dir' and 'upload_interval' from recorder_spec - fv.set_default("recorder_spec", ["name=test-name"]) # Incomplete config - fv.mark_as_parsed() - - # Expecting ValueError since 'upload_dir' and 'upload_interval' are required - with self.assertRaises(ValueError): - GoodputRecorder.from_flags(fv) - - def test_monitoring_initialization_failure(self): - fv = flags.FlagValues() - measurement.define_flags(flag_values=fv) - fv.set_default( - "recorder_spec", - ["name=test-name", "upload_dir=/test/path/to/upload", "upload_interval=15"], + with mock.patch("ml_goodput_measurement.goodput.GoodputRecorder") as mock_recorder_cls: + mock_instance = mock_recorder_cls.return_value + + start_mock = mock.MagicMock() + setattr(mock_instance, expected_start, start_mock) + if expect_end_call and expected_end: + end_mock = mock.MagicMock() + setattr(mock_instance, expected_end, end_mock) + + with recorder.record_event(event, *args, **kwargs): + pass + + mock_recorder_cls.assert_called_once() + start_mock.assert_called_once_with(*args, **kwargs) + if expect_end_call and expected_end: + end_mock.assert_called_once_with(*args, **kwargs) + + def test_record_event_context_manager_handles_runtime_error(self): + cfg = GoodputRecorder.default_config().set( + name="test", + upload_dir="/tmp/test", + upload_interval=1, + ) + recorder = GoodputRecorder(cfg) + + with mock.patch("jax.process_index", return_value=0): + with mock.patch( + "ml_goodput_measurement.goodput.GoodputRecorder" + ) as mock_recorder_cls, mock.patch.object(logging, "warning") as mock_warning: + mock_instance = mock_recorder_cls.return_value + + def raise_runtime_error(*args, **kwargs): + raise RuntimeError("mocked error") + + mock_instance.record_job_start_time.side_effect = raise_runtime_error + mock_instance.record_job_end_time.side_effect = raise_runtime_error + # Should not crash here. 
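+                # Note: record_event catches RuntimeError from the underlying goodput
+                # recorder and logs a warning instead of raising, so the context manager
+                # is expected to exit cleanly here.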
+ with recorder.record_event(measurement.EventType.JOB): + pass + + # Assert warnings were logged for start and end failures + assert mock_warning.call_count == 2 + start_call = mock_warning.call_args_list[0] + end_call = mock_warning.call_args_list[1] + + assert "Failed to record" in start_call.args[0] + assert "Failed to record" in end_call.args[0] + + @parameterized.parameters( + dict(is_pathways_job=False, mock_jax_backend="tpu"), + dict(is_pathways_job=True, mock_jax_backend="proxy"), + dict(is_pathways_job=False, mock_jax_backend=None), + ) + @mock.patch("jax.process_index", return_value=0) + def test_maybe_monitor_goodput(self, _, is_pathways_job, mock_jax_backend): + """Tests the _maybe_monitor_goodput context manager.""" + cfg = GoodputRecorder.default_config().set( + name="test-monitor", + upload_dir="/test", + upload_interval=30, + jax_backend=mock_jax_backend, ) - fv.mark_as_parsed() - - recorder = GoodputRecorder.from_flags(fv) - self.assertIsNone(recorder._monitor) - - # Mock a failure in initializing the GoodputMonitor - with mock.patch( - "ml_goodput_measurement.monitoring.GoodputMonitor", - side_effect=Exception("Failed to initialize GoodputMonitor"), - ): - with self.assertRaises(Exception): - recorder.start_monitoring() - self.assertIsNone(recorder._monitor) - - def test_non_zero_process_index(self): - fv = flags.FlagValues() - measurement.define_flags(flag_values=fv) - fv.set_default( - "recorder_spec", - ["name=test-name", "upload_dir=/test/path/to/upload", "upload_interval=15"], + recorder = GoodputRecorder(cfg) + + with mock.patch("ml_goodput_measurement.monitoring.GoodputMonitor") as mock_monitor_cls: + mock_monitor_instance = mock_monitor_cls.return_value + with recorder._maybe_monitor_goodput(): + pass + + # Verify that GoodputMonitor was instantiated with the correct parameters. 
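+            # pathway_enabled should be True only when jax_backend is "proxy".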
+ mock_monitor_cls.assert_called_once_with( + job_name="test-monitor", + logger_name="goodput_logger_test-monitor", + tensorboard_dir="/test", + upload_interval=30, + monitoring_enabled=True, + pathway_enabled=is_pathways_job, + include_badput_breakdown=True, + ) + mock_monitor_instance.start_goodput_uploader.assert_called_once() + mock_monitor_instance.stop_goodput_uploader.assert_called_once() + + @parameterized.parameters( + dict( + is_rolling_window_enabled=True, + rolling_window_size=[10, 20], + is_pathways_job=False, + mock_jax_backend="tpu", + ), + dict( + is_rolling_window_enabled=False, + rolling_window_size=[], + is_pathways_job=False, + mock_jax_backend="tpu", + ), + dict( + is_rolling_window_enabled=True, + rolling_window_size=[50], + is_pathways_job=True, + mock_jax_backend="proxy", + ), + ) + @mock.patch("jax.process_index", return_value=0) + def test_maybe_monitor_rolling_window( + self, + mock_process_index, + is_rolling_window_enabled, + rolling_window_size, + is_pathways_job, + mock_jax_backend, + ): # pylint: disable=unused-argument + """Tests the rolling window monitoring.""" + cfg = GoodputRecorder.default_config().set( + name="test-rolling", + upload_dir="/test", + upload_interval=30, + rolling_window_size=rolling_window_size, + jax_backend=mock_jax_backend, ) - fv.mark_as_parsed() + recorder = GoodputRecorder(cfg) - recorder = GoodputRecorder.from_flags(fv) - self.assertIsNone(recorder._monitor) + with mock.patch("ml_goodput_measurement.monitoring.GoodputMonitor") as mock_monitor_cls: + mock_monitor_instance = mock_monitor_cls.return_value + if not is_rolling_window_enabled: + with recorder._maybe_monitor_rolling_window_goodput(): + pass + mock_monitor_cls.assert_not_called() + return + with recorder._maybe_monitor_rolling_window_goodput(): + pass - with mock.patch("jax.process_index") as mock_process_index: - mock_process_index.return_value = 1 # Simulate a non-zero process index + mock_monitor_cls.assert_called_once_with( + job_name="test-rolling", + logger_name="goodput_logger_test-rolling", + tensorboard_dir="/test/rolling_window_test-rolling", + upload_interval=30, + monitoring_enabled=True, + pathway_enabled=is_pathways_job, + include_badput_breakdown=True, + ) + + mock_monitor_instance.start_rolling_window_goodput_uploader.assert_called_with( + rolling_window_size + ) + mock_monitor_instance.stop_rolling_window_goodput_uploader.assert_called_once() + + @mock.patch("jax.process_index", return_value=1) + def test_non_zero_process_index_skips_monitoring( + self, mock_process_index + ): # pylint: disable=unused-argument + """Tests that monitoring is skipped on non-zero process indices.""" + cfg = GoodputRecorder.default_config().set( + name="test", upload_dir="/test", upload_interval=30 + ) + recorder = GoodputRecorder(cfg) - try: - recorder.start_monitoring() - except AttributeError: - self.fail("AttributeError was raised unexpectedly.") + with mock.patch("ml_goodput_measurement.monitoring.GoodputMonitor") as mock_monitor_cls: + # Test cumulative goodput monitoring. 
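+            # On non-zero process indices the context manager should yield without
+            # creating a GoodputMonitor.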
+ with recorder._maybe_monitor_goodput(): + pass + mock_monitor_cls.assert_not_called() + + cfg_rolling = GoodputRecorder.default_config().set( + name="test-rolling-skip", + upload_dir="/test", + upload_interval=30, + rolling_window_size=[10, 20], + ) + recorder_rolling = GoodputRecorder(cfg_rolling) + with recorder_rolling._maybe_monitor_rolling_window_goodput(): + pass + mock_monitor_cls.assert_not_called() + + @parameterized.parameters( + dict( + rolling_window_size=[5, 10], + jax_backend="tpu", + expected_monitor_calls=2, # Cumulative & Rolling Window + expect_rolling=True, + expect_cumulative=True, + ), + dict( + rolling_window_size=[], + jax_backend="tpu", + expected_monitor_calls=1, # Cumulative only + expect_rolling=False, + expect_cumulative=True, + ), + dict( + rolling_window_size=[5, 10], + jax_backend=None, # Disables Pathways + expected_monitor_calls=2, + expect_rolling=True, + expect_cumulative=True, + ), + dict( + rolling_window_size=[], + jax_backend=None, + expected_monitor_calls=1, + expect_rolling=False, + expect_cumulative=True, + ), + ) + @mock.patch("jax.process_index", return_value=0) + def test_maybe_monitor_all_goodput( + self, + _, + rolling_window_size, + jax_backend, + expected_monitor_calls, + expect_rolling, + expect_cumulative, + ): + """Tests all goodput monitoring with various configs.""" + cfg = GoodputRecorder.default_config().set( + name="test-all", + upload_dir="/test", + upload_interval=30, + rolling_window_size=rolling_window_size, + jax_backend=jax_backend, + ) + recorder = GoodputRecorder(cfg) + + with mock.patch("ml_goodput_measurement.monitoring.GoodputMonitor") as mock_monitor_cls: + mock_monitor_instance = mock_monitor_cls.return_value + + with recorder.maybe_monitor_all_goodput(): + pass + + self.assertEqual(mock_monitor_cls.call_count, expected_monitor_calls) + + if expect_cumulative: + mock_monitor_instance.start_goodput_uploader.assert_called_once() + mock_monitor_instance.stop_goodput_uploader.assert_called_once() + else: + mock_monitor_instance.start_goodput_uploader.assert_not_called() + mock_monitor_instance.stop_goodput_uploader.assert_not_called() + + if expect_rolling: + mock_monitor_instance.start_rolling_window_goodput_uploader.assert_called_once_with( + rolling_window_size + ) + mock_monitor_instance.stop_rolling_window_goodput_uploader.assert_called_once() + else: + mock_monitor_instance.start_rolling_window_goodput_uploader.assert_not_called() + mock_monitor_instance.stop_rolling_window_goodput_uploader.assert_not_called() diff --git a/axlearn/cloud/gcp/pathways_utils.py b/axlearn/cloud/gcp/pathways_utils.py index e3fd93420..89d379b55 100644 --- a/axlearn/cloud/gcp/pathways_utils.py +++ b/axlearn/cloud/gcp/pathways_utils.py @@ -41,7 +41,7 @@ # Pin to specific pathways image version for stable release. # There is no guarantee that this image will work with newer Jax releases. # However this image was also tested in Maxtext with Jax 0.6.1. -_PATHWAYS_IMAGE_TAG = "jax-0.5.3-patch060625" +_PATHWAYS_IMAGE_TAG = "jax-0.6.2" # The docker image used by pathways proxy container. _PATHWAYS_PROXY_IMAGE = ( f"us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:{_PATHWAYS_IMAGE_TAG}" @@ -135,14 +135,10 @@ class Config(BaseReplicatedJob.Config): Attributes: inner: The wrapped TPUReplicatedJob configuration. - pathways_head_cpu: CPU request for pathways-head container. - pathways_head_mem: Memory request for pathways-head container. 
""" inner: Required[TPUReplicatedJob.Config] = REQUIRED pathways_xla_flags: list[str] = [] - pathways_head_cpu: Optional[str] = None - pathways_head_mem: Optional[str] = None @classmethod def define_flags(cls, fv): @@ -159,24 +155,6 @@ def define_flags(cls, fv): "Example: 'xla_tpu_x=24,megascale_y=true'", **common_kwargs, ) - flags.DEFINE_string( - "pathways_head_cpu", - None, - "CPU request for pathways-head container in cores. Default is 1 core.", - **common_kwargs, - ) - flags.DEFINE_string( - "pathways_head_mem", - None, - "Memory request for pathways-head container in GiB. Default is 16GiB", - **common_kwargs, - ) - - @classmethod - def set_defaults(cls, fv): - super().set_defaults(fv) - fv.set_default("pathways_head_cpu", fv.pathways_head_cpu or "1") - fv.set_default("pathways_head_mem", fv.pathways_head_mem or "16") @classmethod def default_config(cls): @@ -287,14 +265,10 @@ def _build_pathways_head_container(self) -> dict: head_container["env"] = env_list - cpu_req = f"{float(self.config.pathways_head_cpu) * 1000}m" - mem_req = f"{self.config.pathways_head_mem}Gi" - resources = { - "requests": {"cpu": cpu_req, "memory": mem_req}, - "limits": {"cpu": cpu_req, "memory": mem_req}, + head_container["resources"] = { + "requests": {}, + "limits": {}, } - head_container["resources"] = resources - return head_container def _build_pathways_head_sidecar_containers(self) -> list[Nested[Any]]: @@ -320,6 +294,9 @@ def _build_pathways_head_sidecar_containers(self) -> list[Nested[Any]]: f"--resource_manager_address=localhost:{_PATHWAYS_RESOURCE_MANAGER_PORT}", f"--server_port={_PATHWAYS_PROXY_PORT}", f"--gcs_scratch_location={staging_location}", + "--temporary_flags_for_debugging=temporary_flag_for_debugging_pipe_break_on_missing_keepalive=true", + # This should be made configurable + f"--num_elastic_slices={cfg.accelerator.num_replicas}", ] cmd_args.extend(xla_flags_from_options(self._xla_options).split()) @@ -356,6 +333,8 @@ def _build_pathways_head_sidecar_containers(self) -> list[Nested[Any]]: f"--instance_count={pathways_instance_count}", f"--instance_type={pathways_tpu_version}:{system.topology}", f"--gcs_scratch_location={staging_location}", + "--alsologtostderr", + "--temporary_flags_for_debugging=temporary_flag_for_debugging_pipe_break_on_missing_keepalive=true", ], ), ] @@ -429,6 +408,11 @@ def _build_pathways_head_job(self): annotations = _LoadBalancer( jobset_name=cfg.name, replicated_job_name=_PATHWAYS_HEAD_REPLICATED_JOB_NAME ).metadata + annotations.update( + { + "alpha.jobset.sigs.k8s.io/exclusive-topology": "kubernetes.io/hostname", + } + ) spec = dict( parallelism=1, completions=1, @@ -451,6 +435,8 @@ def _build_pathways_worker_container( container = self._inner._build_container() worker_container = copy.deepcopy(container) + worker_container["name"] = "pathways-worker" + env_list = worker_container.get("env", []) pathways_head_address = self._get_pathways_head_address( @@ -503,6 +489,7 @@ def _build_pathways_worker_container( f"--resource_manager_address={pathways_head_address}:" + f"{_PATHWAYS_RESOURCE_MANAGER_PORT}", f"--gcs_scratch_location={cfg.output_dir}/pathways-staging", + "--temporary_flags_for_debugging=temporary_flag_for_debugging_pipe_break_on_missing_keepalive=true", ] mega_scale_args = xla_flags_from_options(self._mxla_options).split() worker_container["args"].extend(mega_scale_args) @@ -581,14 +568,19 @@ def _build_pathways_worker_job( annotations.update( {"alpha.jobset.sigs.k8s.io/exclusive-topology": "cloud.google.com/gke-nodepool"} ) + # Default value for 
suspend and resume. + # References: + # https://github.com/google/pathways-job/blob/4417de7aa23d3c2316e400a3a327512834374475/internal/controller/pathwaysjob_controller.go#L651 + # backoffLimit = system.vms_per_slice * 4 + + # This backoffLimit is just for verifying elastic fast-resume + large_number = 1000 + backoffLimit = system.vms_per_slice * 4 * large_number spec = dict( parallelism=system.vms_per_slice, completions=system.vms_per_slice, - # Default value for suspend and resume. - # References: - # https://github.com/google/pathways-job/blob/4417de7aa23d3c2316e400a3a327512834374475/internal/controller/pathwaysjob_controller.go#L651 - backoffLimit=system.vms_per_slice * 4, + backoffLimit=backoffLimit, template=self._build_pathways_worker_pod(pathways_worker_replicated_job_index), ) worker_job = dict( @@ -608,7 +600,7 @@ def __call__(self) -> Sequence[Nested[Any]]: ), dict( name=_PATHWAYS_WORKER_REPLICATED_JOB_NAME, - replicas=cfg.accelerator.num_replicas, + replicas=cfg.accelerator.num_replicas + 1, template=self._build_pathways_worker_job(), ), ] diff --git a/axlearn/common/launch_trainer.py b/axlearn/common/launch_trainer.py index bba28533e..1aed98b06 100644 --- a/axlearn/common/launch_trainer.py +++ b/axlearn/common/launch_trainer.py @@ -2,7 +2,9 @@ """Utilities to launch a trainer.""" +import contextlib import json +import logging as py_logging import os from typing import Any, Optional @@ -15,6 +17,8 @@ from axlearn.common.utils import MeshShape, get_data_dir, infer_mesh_shape from axlearn.experiments import TrainerConfigFn, get_named_trainer_config +py_logging.raiseException = True + # Trainer-specific flags. flags.DEFINE_string( "module", @@ -128,8 +132,8 @@ def get_trainer_config( return trainer_config -def run_trainer(trainer_config: SpmdTrainer.Config) -> Any: - measurement.record_event(measurement.Event.START_JOB) +def _run_trainer_impl(trainer_config: SpmdTrainer.Config) -> Any: + """Instantiates and runs the trainer.""" trainer_config_debug_string = trainer_config.debug_string() logging.info("Trainer config:\n%s", trainer_config_debug_string) if jax.process_index() == 0: @@ -147,8 +151,52 @@ def run_trainer(trainer_config: SpmdTrainer.Config) -> Any: f, ) - trainer: SpmdTrainer = trainer_config.instantiate(parent=None) - prng_key = jax.random.PRNGKey(seed=FLAGS.trainer_prng_seed) - output = trainer.run(prng_key) - measurement.record_event(measurement.Event.END_JOB) + if FLAGS.jax_backend == "proxy": + # pylint: disable-next=import-error,import-outside-toplevel + from pathwaysutils.elastic import manager + elastic_manager = manager.Manager() + max_attempts = 5 + for attempt_index in range(max_attempts): + try: + logging.info(f"Elastic attempt {attempt_index + 1}/{max_attempts}") + + timeout = 10 * 60 # ten minutes + logging.info(f"Waiting up to {timeout} s for slices to be ready") + elastic_manager.wait_for_slices(timeout=timeout) + + trainer: SpmdTrainer = trainer_config.instantiate(parent=None) + prng_key = jax.random.PRNGKey(seed=FLAGS.trainer_prng_seed) + output = trainer.run(prng_key) + break + except jax.errors.JaxRuntimeError as error: + if not elastic_manager.is_error_due_to_slice_down(error): + raise + try: + logging.info("Trying to clean up ongoing traces") + jax.profiler.stop_trace() + logging.info("Successfully cleaned up ongoing traces") + except (RuntimeError, ValueError) as e: + logging.info("No ongoing traces to clean up") + except Exception as e: + logging.exception("Error trying to clean up ongoing traces") + raise + + jax.clear_caches() + for array in 
jax.live_arrays(): + array.delete() + + else: + trainer: SpmdTrainer = trainer_config.instantiate(parent=None) + prng_key = jax.random.PRNGKey(seed=FLAGS.trainer_prng_seed) + output = trainer.run(prng_key) + return output + + +def run_trainer(trainer_config: SpmdTrainer.Config) -> Any: + recorder = measurement.global_recorder + job_events_manager = ( + recorder.record_event(measurement.EventType.JOB) if recorder else contextlib.nullcontext() + ) + with job_events_manager: + return _run_trainer_impl(trainer_config) diff --git a/axlearn/common/launch_trainer_main.py b/axlearn/common/launch_trainer_main.py index 2f617b4cd..8d170a950 100644 --- a/axlearn/common/launch_trainer_main.py +++ b/axlearn/common/launch_trainer_main.py @@ -13,7 +13,6 @@ def main(_): launch.setup() trainer_config = launch_trainer.get_trainer_config() trainer_config.set(recorder=config_for_function(lambda: measurement.global_recorder)) - measurement.start_monitoring() launch_trainer.run_trainer(trainer_config) diff --git a/axlearn/common/measurement.py b/axlearn/common/measurement.py index b0a40a85f..2ce3796bc 100644 --- a/axlearn/common/measurement.py +++ b/axlearn/common/measurement.py @@ -2,6 +2,7 @@ """A library to measure e2e metrics like goodput.""" +import contextlib import enum import importlib from typing import Optional, TypeVar @@ -41,6 +42,26 @@ class Event(enum.Enum): END_CUSTOM_BADPUT_EVENT = "END_CUSTOM_BADPUT_EVENT" +class EventType(enum.Enum): + """Event to be recorded. + + Attributes: + JOB: Start and end of the job. + STEP: Start of a training step. Should be recorded with `step` as a positional arg. + ACCELERATOR_INIT: Start and end of accelerator mesh initialization. + TRAINING_PREPARATION: Start and end of training preparation. + DATA_LOADING: Start and end of data loading. + CUSTOM_BADPUT_EVENT: Start and end of custom badput events. + """ + + JOB = "job" + STEP = "step" + ACCELERATOR_INIT = "tpu_init" + TRAINING_PREPARATION = "training_preparation" + DATA_LOADING = "data_loading" + CUSTOM_BADPUT_EVENT = "custom_badput_event" + + class Recorder(Configurable): """The base interface for collecting e2e metrics.""" @@ -67,6 +88,20 @@ def start_monitoring(self, **kwargs): """Starts computing and uploading metrics at some configured interval in the background.""" raise NotImplementedError(type(self)) + @contextlib.contextmanager + def record_event(self, event: Event, *args, **kwargs): + """A context manager to record the start and end of an event.""" + # pylint: disable=unnecessary-pass + # pylint: disable=unused-argument + try: + yield + finally: + pass + + @contextlib.contextmanager + def maybe_monitor_all_goodput(self): + yield + _recorders: dict[str, type] = {} _T = TypeVar("_T") diff --git a/axlearn/common/measurement_test.py b/axlearn/common/measurement_test.py index c9043f20b..43d7345ed 100644 --- a/axlearn/common/measurement_test.py +++ b/axlearn/common/measurement_test.py @@ -92,3 +92,7 @@ def test_initialize(self, recorder_type, expected): ) as mock_start_monitoring: measurement.start_monitoring() mock_start_monitoring.assert_called_once() + + # Ensure that maybe_monitor_all_goodput does not fail (just enter and exit context). 
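+        # The base Recorder's maybe_monitor_all_goodput is a no-op context manager,
+        # so this should enter and exit without side effects.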
+ with measurement.global_recorder.maybe_monitor_all_goodput(): + pass diff --git a/axlearn/common/trainer.py b/axlearn/common/trainer.py index 0603f7bf9..89656bb96 100644 --- a/axlearn/common/trainer.py +++ b/axlearn/common/trainer.py @@ -241,116 +241,121 @@ def __init__( self._device_monitor = maybe_instantiate(cfg.device_monitor) self._recorder = maybe_instantiate(cfg.recorder) self._is_initialized: bool = False - self._maybe_record_event(measurement.Event.START_ACCELERATOR_INIT) + # Accelerator initialization. + with self._record_event(measurement.EventType.ACCELERATOR_INIT): + if cfg.model.dtype is None: + raise ValueError(f"dtype must be explicitly specified for {self.path()}.model") + if cfg.model.param_init is None: + cfg.model.param_init = DefaultInitializer.default_config() + logging.info( + "model.param_init is not specified. Default to DefaultInitializer: %s", + cfg.model.param_init, + ) - if cfg.model.dtype is None: - raise ValueError(f"dtype must be explicitly specified for {self.path()}.model") - if cfg.model.param_init is None: - cfg.model.param_init = DefaultInitializer.default_config() - logging.info( - "model.param_init is not specified. Default to DefaultInitializer: %s", - cfg.model.param_init, + self._per_param_train_dtype = maybe_instantiate( + canonicalize_per_param_dtype(cfg.train_dtype) ) - self._per_param_train_dtype = maybe_instantiate( - canonicalize_per_param_dtype(cfg.train_dtype) - ) - - # Create the device mesh. - if devices is None: - self._step_log( - "Devices: global=%s local=%s %s", - jax.device_count(), - jax.local_device_count(), - [device.platform for device in jax.local_devices()], - ) - else: - local_devices = [d for d in devices.flatten() if d.process_index == jax.process_index()] - self._step_log( - "Devices: global=%s local=%s %s", - len(devices), - len(local_devices), - [device.platform for device in local_devices], - ) - self._step_log("Mesh shape: %s", cfg.mesh_shape) - devices = ( - utils.create_device_mesh(mesh_shape=cfg.mesh_shape) if devices is None else devices - ) - mesh = jax.sharding.Mesh(devices, cfg.mesh_axis_names) - self._step_log("Global mesh: %s", mesh) - self._mesh = mesh - self._context_manager: Callable[[], ContextManager] = ( - maybe_instantiate(cfg.context_manager) or contextlib.nullcontext - ) - xsc_check_policy = None - if cfg.xsc_check_policy: - if jax.default_backend() != "tpu": - # XSC is currently only supported on TPU XLA backend. - logging.warning( - "xsc_check_policy was set for non-TPU XLA backend. Running without XSC." + # Create the device mesh. + if devices is None: + self._step_log( + "Devices: global=%s local=%s %s", + jax.device_count(), + jax.local_device_count(), + [device.platform for device in jax.local_devices()], ) else: - xsc_check_policy = maybe_instantiate(cfg.xsc_check_policy) - self._xsc_check_policy: Optional[Callable[[int], bool]] = xsc_check_policy - self._compiled_train_step: Optional[jax.stages.Compiled] = None - - # Create all children within the mesh context so that utils.input_partition_spec() works - # properly. - with self.mesh(): - self.input: Input = self._add_child( - "input", - maybe_set_config( - cfg.input, partition_spec=PartitionSpec(cfg.batch_axis_names), is_training=True - ), - ) - # Start from the beginning of the input dataset by default. 
- self._input_iter = iter(self.input.dataset()) - cfg.summary_writer.dir = cfg.summary_writer.dir or os.path.join( - cfg.dir, "summaries", "train_train" - ) - self._add_child("summary_writer", cfg.summary_writer) - self._add_child("model", cfg.model) - self._add_child("learner", cfg.learner) - cfg.checkpointer.dir = cfg.checkpointer.dir or os.path.join(cfg.dir, "checkpoints") - self._add_child("checkpointer", cfg.checkpointer) - if cfg.init_state_builder is not None: - self._add_child("init_state_builder", cfg.init_state_builder) - - self._model_param_specs = self.model.create_parameter_specs_recursively() - model_param_partition_specs = jax.tree.map( - lambda spec: spec.mesh_axes, self._model_param_specs - ) - for name, spec in utils.flatten_items(self._model_param_specs): - self._step_log("Model param spec: %s=%s", name, spec) - self._learner_state_partition_specs = self.learner.create_state_partition_specs( - self._model_param_specs - ) - for name, spec in utils.flatten_items(self._learner_state_partition_specs): - self._step_log("Learner state spec: %s=%s", name, spec) - self._trainer_state_specs = TrainerState( - prng_key=ParameterSpec(dtype=jnp.uint32, shape=[4], mesh_axes=PartitionSpec(None)), - model=self._model_param_specs, - learner=self._learner_state_partition_specs, + local_devices = [ + d for d in devices.flatten() if d.process_index == jax.process_index() + ] + self._step_log( + "Devices: global=%s local=%s %s", + len(devices), + len(local_devices), + [device.platform for device in local_devices], + ) + self._step_log("Mesh shape: %s", cfg.mesh_shape) + devices = ( + utils.create_device_mesh(mesh_shape=cfg.mesh_shape) if devices is None else devices ) - self._trainer_state_partition_specs: TrainerState = jax.tree.map( - lambda spec: spec.sharding, self._trainer_state_specs + mesh = jax.sharding.Mesh(devices, cfg.mesh_axis_names) + self._step_log("Global mesh: %s", mesh) + self._mesh = mesh + self._context_manager: Callable[[], ContextManager] = ( + maybe_instantiate(cfg.context_manager) or contextlib.nullcontext ) - # Create evalers, which depend on model_param_partition_specs. - self._evalers = {} - for evaler_name, evaler_cfg in cfg.evalers.items(): - evaler_cfg.summary_writer.dir = evaler_cfg.summary_writer.dir or os.path.join( - cfg.dir, "summaries", evaler_name + xsc_check_policy = None + if cfg.xsc_check_policy: + if jax.default_backend() != "tpu": + # XSC is currently only supported on TPU XLA backend. + logging.warning( + "xsc_check_policy was set for non-TPU XLA backend. Running without XSC." + ) + else: + xsc_check_policy = maybe_instantiate(cfg.xsc_check_policy) + self._xsc_check_policy: Optional[Callable[[int], bool]] = xsc_check_policy + self._compiled_train_step: Optional[jax.stages.Compiled] = None + + # Create all children within the mesh context so that utils.input_partition_spec() works + # properly. + with self.mesh(): + if cfg.batch_axis_names is not None: + cfg.input = maybe_set_config( + cfg.input, partition_spec=PartitionSpec(cfg.batch_axis_names) + ) + self.input: Input = self._add_child( + "input", maybe_set_config(cfg.input, is_training=True) + ) + # Start from the beginning of the input dataset by default. 
+ self._input_iter = iter(self.input.dataset()) + cfg.summary_writer.dir = cfg.summary_writer.dir or os.path.join( + cfg.dir, "summaries", "train_train" ) - maybe_set_config( - evaler_cfg.input, partition_spec=PartitionSpec(cfg.batch_axis_names) + self._add_child("summary_writer", cfg.summary_writer) + self._add_child("model", cfg.model) + self._add_child("learner", cfg.learner) + cfg.checkpointer.dir = cfg.checkpointer.dir or os.path.join(cfg.dir, "checkpoints") + self._add_child("checkpointer", cfg.checkpointer) + if cfg.init_state_builder is not None: + self._add_child("init_state_builder", cfg.init_state_builder) + + self._model_param_specs = self.model.create_parameter_specs_recursively() + model_param_partition_specs = jax.tree.map( + lambda spec: spec.mesh_axes, self._model_param_specs ) - self._evalers[evaler_name] = self._add_child( - evaler_name, - evaler_cfg, - model=self.model, - model_param_partition_specs=model_param_partition_specs, + for name, spec in utils.flatten_items(self._model_param_specs): + self._step_log("Model param spec: %s=%s", name, spec) + self._learner_state_partition_specs = self.learner.create_state_partition_specs( + self._model_param_specs ) - self._maybe_record_event(measurement.Event.END_ACCELERATOR_INIT) + for name, spec in utils.flatten_items(self._learner_state_partition_specs): + self._step_log("Learner state spec: %s=%s", name, spec) + self._trainer_state_specs = TrainerState( + prng_key=ParameterSpec( + dtype=jnp.uint32, shape=[4], mesh_axes=PartitionSpec(None) + ), + model=self._model_param_specs, + learner=self._learner_state_partition_specs, + ) + self._trainer_state_partition_specs: TrainerState = jax.tree.map( + lambda spec: spec.sharding, self._trainer_state_specs + ) + # Create evalers, which depend on model_param_partition_specs. + self._evalers = {} + for evaler_name, evaler_cfg in cfg.evalers.items(): + evaler_cfg.summary_writer.dir = evaler_cfg.summary_writer.dir or os.path.join( + cfg.dir, "summaries", evaler_name + ) + if cfg.batch_axis_names is not None: + maybe_set_config( + evaler_cfg.input, partition_spec=PartitionSpec(cfg.batch_axis_names) + ) + self._evalers[evaler_name] = self._add_child( + evaler_name, + evaler_cfg, + model=self.model, + model_param_partition_specs=model_param_partition_specs, + ) @property def step(self): @@ -368,6 +373,15 @@ def trainer_state_specs(self): def trainer_state_partition_specs(self): return self._trainer_state_partition_specs + @contextlib.contextmanager + def _record_event(self, event: measurement.EventType, *args, **kwargs): + """A helper to record an event if a recorder is configured.""" + if self._recorder: + with self._recorder.record_event(event, *args, **kwargs) as event_manager: + yield event_manager + else: + yield + def _train_step_input_partition_specs(self): # Note that subclasses may override this method to set a partition spec for pjit which is # different from that of the input partition spec. @@ -525,10 +539,6 @@ def _should_force_run_evals( ) return force_run_evals - def _maybe_record_event(self, event: measurement.Event, *args, **kwargs): - if self._recorder is not None: - self._recorder.record(event, *args, **kwargs) - # pylint: disable-next=too-many-statements,too-many-branches def run( self, prng_key: Tensor, *, return_evaler_summaries: Optional[Union[bool, set[str]]] = None @@ -554,6 +564,7 @@ def run( different types of values such as WeightedScalar, Tensor, or string, depending on the specific `metric_calculator` config of the evaler. 
""" + with ( ( self._device_monitor.start_monitoring() @@ -564,6 +575,7 @@ def run( self.mesh(), jax.log_compiles(self.vlog_is_on(1)), self._context_manager(), + self._recorder.maybe_monitor_all_goodput(), ): cfg = self.config # Check if need to force run evals at the last training step. @@ -572,8 +584,9 @@ def run( ) # Prepare training. - if not self._prepare_training(prng_key): - return None + with self._record_event(measurement.EventType.TRAINING_PREPARATION): + if not self._prepare_training(prng_key): + return None self._is_initialized = True @@ -586,10 +599,10 @@ def run( input_iterator = self.input.batches(self._input_iter) while True: - self._maybe_record_event(measurement.Event.START_DATA_LOADING) try: - input_batch = next(input_iterator) - self._maybe_record_event(measurement.Event.END_DATA_LOADING) + with self._record_event(measurement.EventType.DATA_LOADING): + input_batch = next(input_iterator) + logging.log_first_n( logging.INFO, "input_batch=%s", 3, utils.shapes(input_batch) ) @@ -599,18 +612,18 @@ def run( self._step = self._step + 1 self.vlog(3, "Start step %s", self.step) - self._maybe_record_event(measurement.Event.START_STEP, self._step) - output = self._run_step( - utils.host_to_global_device_array( - input_batch, - partition=self._train_step_input_partition_specs(), - ), - force_run_evals=( - force_run_eval_sets_at_max_step - if self.step >= cfg.max_step - else None - ), - ) + with self._record_event(measurement.EventType.STEP, self._step): + output = self._run_step( + utils.host_to_global_array( + input_batch, + partition=self._train_step_input_partition_specs(), + ), + force_run_evals=( + force_run_eval_sets_at_max_step + if self.step >= cfg.max_step + else None + ), + ) self.vlog(3, "Done step %s", self.step) num_steps += 1 if num_steps % 100 == 0: @@ -624,9 +637,6 @@ def run( self._step_log("Reached max_step=%s. Stopping", cfg.max_step) break except StopIteration: - # Add END_DATA_LOADING event here to close the unpaired START_DATA_LOADING - # event. - self._maybe_record_event(measurement.Event.END_DATA_LOADING) break if self.step < cfg.max_step: self._step_log("Reached end of inputs. Stopping") @@ -867,7 +877,6 @@ def _prepare_training(self, prng_key: Tensor) -> bool: A boolean indicating whether the model training should start. If not, return None from the `run` function. """ - self._maybe_record_event(measurement.Event.START_TRAINING_PREPARATION) cfg = self.config # Attempt to restore the latest checkpoint, which may contain a saved `_input_iter`. 
@@ -900,7 +909,6 @@ def _prepare_training(self, prng_key: Tensor) -> bool: return False self._jit_train_step = self._pjit_train_step() - self._maybe_record_event(measurement.Event.END_TRAINING_PREPARATION) return True def restore_checkpoint(self, restore_step: Optional[int] = None) -> Optional[int]: @@ -1041,36 +1049,29 @@ def _get_compiled_train_step_fn( mesh_shape=cfg.mesh_shape, mesh_axis_names=cfg.mesh_axis_names, device_kind=device_kind ) if not with_xsc: - self._maybe_record_event( - measurement.Event.START_CUSTOM_BADPUT_EVENT, - custom_badput_event_type="COMPILATION_NO_XSC", - ) - self._compiled_train_step = self.compile_train_step( - trainer_state=trainer_state, input_batch=input_batch, compiler_options=options - ) - self._maybe_record_event( - measurement.Event.END_CUSTOM_BADPUT_EVENT, + with self._record_event( + measurement.EventType.CUSTOM_BADPUT_EVENT, custom_badput_event_type="COMPILATION_NO_XSC", - ) + ): + self._compiled_train_step = self.compile_train_step( + trainer_state=trainer_state, input_batch=input_batch, compiler_options=options + ) return self._compiled_train_step + logging.log_first_n(logging.INFO, "Compiling XSC train step.", 1) - self._maybe_record_event( - measurement.Event.START_CUSTOM_BADPUT_EVENT, - custom_badput_event_type="COMPILATION_WITH_XSC", - ) - compiled_jit_train_step_fn = self.compile_train_step( - trainer_state=trainer_state, - input_batch=input_batch, - compiler_options=options - | infer_xsc_compiler_options( - halt_on_detection=True, repeat_count=1, device_kind=device_kind - ), - ) - self._maybe_record_event( - measurement.Event.END_CUSTOM_BADPUT_EVENT, + with self._record_event( + measurement.EventType.CUSTOM_BADPUT_EVENT, custom_badput_event_type="COMPILATION_WITH_XSC", - ) + ): + compiled_jit_train_step_fn = self.compile_train_step( + trainer_state=trainer_state, + input_batch=input_batch, + compiler_options=options + | infer_xsc_compiler_options( + halt_on_detection=True, repeat_count=1, device_kind=device_kind + ), + ) return compiled_jit_train_step_fn def _run_step( @@ -1127,26 +1128,23 @@ def _run_eval( force_runs: Optional[set[str]] = None, ) -> dict[str, Any]: """Runs evaluations and returns the corresponding summaries.""" - self._maybe_record_event( - measurement.Event.START_CUSTOM_BADPUT_EVENT, custom_badput_event_type="EVAL" - ) - evaler_summaries = {} - # Note: we will use the same eval key as the training keys of the future step, - # which should be okay. - prng_key = self._trainer_state.prng_key - for evaler_name, evaler in self._evalers.items(): - prng_key, summaries, _ = evaler.eval_step( - self.step, - prng_key=prng_key, - model_params=self.model_params_for_eval(), - train_summaries=train_summaries, - force_run=bool(force_runs is not None and evaler_name in force_runs), - ) - evaler_summaries[evaler_name] = summaries - self._maybe_record_event( - measurement.Event.END_CUSTOM_BADPUT_EVENT, custom_badput_event_type="EVAL" - ) - return evaler_summaries + with self._record_event( + measurement.EventType.CUSTOM_BADPUT_EVENT, custom_badput_event_type="EVAL" + ): + evaler_summaries = {} + # Note: we will use the same eval key as the training keys of the future step, + # which should be okay. 
+            prng_key = self._trainer_state.prng_key
+            for evaler_name, evaler in self._evalers.items():
+                prng_key, summaries, _ = evaler.eval_step(
+                    self.step,
+                    prng_key=prng_key,
+                    model_params=self.model_params_for_eval(),
+                    train_summaries=train_summaries,
+                    force_run=bool(force_runs is not None and evaler_name in force_runs),
+                )
+                evaler_summaries[evaler_name] = summaries
+            return evaler_summaries

     def _pjit_train_step(self) -> jax.stages.Wrapped:
         return pjit(
diff --git a/axlearn/experiments/text/gpt/common.py b/axlearn/experiments/text/gpt/common.py
index 57d606dab..7539bad5d 100644
--- a/axlearn/experiments/text/gpt/common.py
+++ b/axlearn/experiments/text/gpt/common.py
@@ -504,7 +504,8 @@ def mixture_train_input_source(
             config_for_function(input_tf_data.tfds_dataset).set(
                 dataset_name=component.name,
                 split=component.split,
-                train_shuffle_buffer_size=64 * component.shuffle_buffer_size,
+                # train_shuffle_buffer_size=64 * component.shuffle_buffer_size,
+                train_shuffle_buffer_size=0,
                 read_config=tfds_read_config(),
             )
         )
@@ -643,6 +644,7 @@ def get_trainer_config_fn(
     keep_every_n_steps: int = 50_000,
     save_every_n_steps: Optional[int] = None,
    init_state_builder: Optional[state_builder.Builder.Config] = None,
+    checkpointer: str = "",
 ) -> TrainerConfigFn:
     """Builds a TrainerConfigFn according to the model and input specs.

@@ -710,12 +712,56 @@ def config_fn() -> InstantiableConfig:
             )
             cfg.evalers[name] = evaler_cfg
         # Summaries and checkpoints.
-        cfg.checkpointer.save_policy = config_for_function(every_n_steps_and_last_policy).set(
-            n=save_every_n_steps or min(eval_every_n_steps, 5_000),
-            max_step=max_step,
-        )
-        cfg.checkpointer.keep_every_n_steps = min(max_step, keep_every_n_steps)
-        cfg.checkpointer.keep_last_n = 3
+        # cfg.checkpointer.save_policy = config_for_function(every_n_steps_and_last_policy).set(
+        #     n=save_every_n_steps or min(eval_every_n_steps, 5_000),
+        #     max_step=max_step,
+        # )
+        calculated_save_every_n_steps = save_every_n_steps or min(eval_every_n_steps, 100)
+
+        if not checkpointer:
+            cfg.checkpointer.save_policy = config_for_function(every_n_steps_and_last_policy).set(
+                n=calculated_save_every_n_steps,
+                max_step=max_step,
+            )
+            # cfg.checkpointer.keep_every_n_steps = min(max_step, keep_every_n_steps)
+            cfg.checkpointer.keep_last_n = 3
+        elif checkpointer == "OrbaxEmergencyCheckpointer":
+            # Prevent global dependency on Orbax.
+            # pylint: disable-next=import-outside-toplevel
+            from axlearn.common.checkpointer_orbax_emergency import OrbaxEmergencyCheckpointer
+
+            ckpt_config: OrbaxEmergencyCheckpointer.Config = (
+                OrbaxEmergencyCheckpointer.default_config()
+            )
+            ckpt_config.save_policy = config_for_function(every_n_steps_and_last_policy).set(
+                n=calculated_save_every_n_steps,
+                max_step=max_step,
+            )
+            ckpt_config.local_save_policy = config_for_function(every_n_steps_and_last_policy).set(
+                n=calculated_save_every_n_steps,
+                max_step=max_step,
+            )
+            ckpt_config.local_dir = "/host-tmp/checkpoints"
+            # ckpt_config.keep_every_n_steps = min(max_step, keep_every_n_steps)
+            ckpt_config.keep_last_n = 3
+            ckpt_config.replica_axis_index = 1
+            cfg.checkpointer = ckpt_config
+        elif checkpointer == "OrbaxRegularCheckpointer":
+            # Prevent global dependency on Orbax.
+            # pylint: disable-next=import-outside-toplevel
+            from axlearn.common.checkpointer_orbax import OrbaxCheckpointer
+
+            ckpt_config: OrbaxCheckpointer.Config = OrbaxCheckpointer.default_config()
+            ckpt_config.save_policy = config_for_function(every_n_steps_and_last_policy).set(
+                n=calculated_save_every_n_steps,
+                max_step=max_step,
+            )
+            # ckpt_config.keep_every_n_steps = min(max_step, keep_every_n_steps)
+            ckpt_config.keep_last_n = 3
+            cfg.checkpointer = ckpt_config
+
+        # cfg.checkpointer.keep_every_n_steps = min(max_step, keep_every_n_steps)
+        # cfg.checkpointer.keep_last_n = 3
         cfg.summary_writer.write_every_n_steps = min(eval_every_n_steps, 100)
         cfg.summary_writer.max_queue = 1000
         if len(mesh_axis_names) != len(mesh_shape):
diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py
index 57c910c70..c5f73f213 100644
--- a/axlearn/experiments/text/gpt/fuji.py
+++ b/axlearn/experiments/text/gpt/fuji.py
@@ -15,6 +15,7 @@
 import itertools
 from typing import Any, List, NamedTuple, Optional, Union

+import jax
 from jax.ad_checkpoint import checkpoint_policies as jax_remat_policies

 from axlearn.common import causal_lm, config
@@ -67,7 +68,7 @@
 from axlearn.experiments.text.gpt.common import scaled_hidden_dim
 from axlearn.experiments.trainer_config_utils import TrainerConfigFn, V6eFlashConfigModifier

-MODEL_SIZES = ("test", "1B", "3B", "7B", "8B", "70B")
+MODEL_SIZES = ("test", "1B", "3B", "7B", "8B", "70B", "150B")

 class Version(enum.Enum):
@@ -113,6 +114,7 @@ class Version(enum.Enum):
         "test": 2 * (1024**4),  # 2T tokens
         "7B": 2 * (1024**4),  # 2T tokens
         "70B": 2 * (1024**4),  # 2T tokens
+        "150B": 2 * (1024**4),  # 2T tokens
     },
     Version.V3: {
         "test": 15 * (1024**4),  # 15T tokens
@@ -120,6 +122,7 @@ class Version(enum.Enum):
         "3B": 15 * (1024**4),  # 15T tokens
         "7B": 15 * (1024**4),  # 15T tokens
         "70B": 15 * (1024**4),  # 15T tokens
+        "150B": 2 * (1024**4),  # 2T tokens
     },
     Version.V3_TIKTOKEN: {
         "test": 15 * (1024**4),  # 15T tokens
@@ -127,6 +130,7 @@ class Version(enum.Enum):
         "3B": 15 * (1024**4),  # 15T tokens
         "8B": 15 * (1024**4),  # 15T tokens
         "70B": 15 * (1024**4),  # 15T tokens
+        "150B": 2 * (1024**4),  # 2T tokens
     },
 }
@@ -378,8 +382,9 @@ def get_trainer_kwargs(
             ),
             learner_kwargs=dict(peak_lr=3e-4, weight_decay=0.1),
             max_sequence_length=max_sequence_length,
-            train_batch_size=train_batch_size,
+            train_batch_size=len(jax.devices()),
             max_step=max_step,
+            save_every_n_steps=100,
             mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8),
             mesh_rules=(
                 # Step time:
@@ -392,6 +397,24 @@
                 # tpu-v4-(1024|2048).
                 ("tpu-v4-(1024|2048)", mesh_shape_from_axes(data=-1, fsdp=16)),
                 # tpu-v5e.
+                (
+                    "tpu-v5litepod-32-2",
+                    ChainConfigModifier.default_config().set(
+                        config_modifiers=[
+                            MeshShapeModifier.default_config().set(
+                                mesh_shape=mesh_shape_from_axes(data=-1, fsdp=32)
+                            ),
+                            RematSpecModifier.default_config().set(
+                                remat_policies={
+                                    "model.decoder.transformer.layer": RematSpec(
+                                        prevent_cse=False,
+                                        policy=offload_dots_saveable_policy,
+                                    ),
+                                }
+                            ),
+                        ],
+                    ),
+                ),
                 (
                     "tpu-v5litepod-256",
                     ChainConfigModifier.default_config().set(
@@ -809,6 +832,55 @@ def get_trainer_kwargs(
                 ),
             ),
         )
+    elif model_size == "150B":
+        trainer_kwargs = dict(
+            model_kwargs=dict(
+                num_layers=80,
+                hidden_dim=128 * 96,
+                num_heads=96,
+                # No GQA support in V1 models, so num_kv_heads is the same as num_heads.
+                num_kv_heads=None if version == Version.V1 else 8,
+                ffn_dim=scaled_hidden_dim(scale=3.5, round_up_to_multiples_of=256),
+                rope_theta=rope_theta,
+                shared_lm_head=False,
+                flash_attention=flash_attention,
+            ),
+            learner_kwargs=dict(peak_lr=1.5e-4, weight_decay=0.1),
+            max_sequence_length=max_sequence_length,
+            train_batch_size=train_batch_size,  # number of devices times 4 chips per device times 4096 samples per chip # train_batch_size,
+            max_step=10_000,  # max_step,
+            save_every_n_steps=20_000_000,
+            mesh_shape=mesh_shape_from_axes(data=-1, fsdp=64, model=4),
+            mesh_rules=(
+                (
+                    # Target per-device token count = 4k.
+                    # PDBS = 0.5 at 8k context.
+                    # Each slice can train a batch size of 128.
+                    "tpu-v6e-256.*",
+                    ChainConfigModifier.default_config().set(
+                        config_modifiers=[
+                            MeshShapeModifier.default_config().set(
+                                mesh_shape=mesh_shape_from_axes(data=-1, fsdp=256)
+                            ),
+                            RematSpecModifier.default_config().set(
+                                remat_policies={
+                                    "model.decoder.transformer.layer": RematSpec(
+                                        prevent_cse=False,
+                                        policy=save_and_offload_only_these_names_regex(
+                                            names_which_can_be_offloaded=".*input",
+                                            names_which_can_be_saved=None,
+                                            offload_src="device",
+                                            offload_dst="pinned_host",
+                                        ),
+                                    ),
+                                }
+                            ),
+                            V6eFlashConfigModifier.default_config(),
+                        ],
+                    ),
+                ),
+            ),
+        )
     else:
         raise NotImplementedError(f"Unknown model size {model_size}.")
     model_kwargs = trainer_kwargs.pop("model_kwargs")
@@ -914,17 +986,25 @@ def trainer_configs(
     """
     arch = "fuji"
     config_map = {}
-    for version, model_size, flash_attention in itertools.product(
-        Version, MODEL_SIZES, [True, False]
+    for version, model_size, flash_attention, checkpointer in itertools.product(
+        Version, MODEL_SIZES, [True, False], ["", "OrbaxEmergencyCheckpointer", "OrbaxRegularCheckpointer"],
     ):
         if model_size not in TOTAL_TOKENS[version]:
             # This combination does not exist.
             continue
         vocab_size = VOCAB_SIZE[version]
+        current_suffix_parts = []
+        if flash_attention:
+            current_suffix_parts.append("-flash")
+        if checkpointer == "OrbaxEmergencyCheckpointer":
+            current_suffix_parts.append("-orbaxem")
+        elif checkpointer == "OrbaxRegularCheckpointer":
+            current_suffix_parts.append("-orbax")
+        current_suffix = "".join(current_suffix_parts)
         config_name = make_config_name(
             arch=arch,
             model_size=model_size,
             version=f"v{version.value}",
-            suffix="-flash" if flash_attention else "",
+            suffix=current_suffix,
         )
         kwargs = get_trainer_kwargs(
             model_size, vocab_size=vocab_size, version=version, flash_attention=flash_attention
@@ -939,6 +1019,7 @@ def trainer_configs(
             evalers=evaler_config_dict(
                 eval_input_sources(vocab_size=vocab_size, max_sequence_length=max_sequence_length),
             ),
+            checkpointer=checkpointer,
             **kwargs,
         )

diff --git a/docs/05-Goodput-Monitoring.md b/docs/05-Goodput-Monitoring.md
index ca1452c19..cb17f6989 100644
--- a/docs/05-Goodput-Monitoring.md
+++ b/docs/05-Goodput-Monitoring.md
@@ -1,10 +1,14 @@
 # ML Goodput Monitoring

-AXLearn supports automatic measurement and upload of workload metrics such as
-Goodput, Badput Breakdown and Step Time Deviation using the ML Goodput
-Measurement library.
+AXLearn supports automatic measurement and upload of a wide range of workload
+metrics using the **ML Goodput Measurement** library. This includes:
+* **Goodput** and **Badput Breakdown**
+* **Step Metrics** (Ideal Step Time, Step Time Deviation, Last Productive Step etc.)
+* **Workload Hang Metrics** (Disruption Count, Step Info)
+* **Rolling Window Goodput & Badput Breakdown**

 The [ML Goodput Measurement](https://github.com/AI-Hypercomputer/ml-goodput-measurement)
 library currently supports monitoring workloads running on Google Cloud Platform. For more
 information on details of the library, visit the Github page or the
 [ml-goodput-measurement](https://pypi.org/project/ml-goodput-measurement/) PyPI package
 documentation.
+
 ### What is Goodput
 Goodput is the metric that measures the efficiency of model training jobs, i.e.
 productive time spent on training progress proportional to the total time spent
@@ -15,12 +19,26 @@ improve to get the most value from their accelerators.
 Badput is the metric that measures time that a workload spent on anything that
 is not productive training proportional to the total time spent by the workload.
 For example, the time spent in accelerator initialization, training preparation,
-program startup, data loading, portions of checkpointing, disruptions and
-wasted progress since the last checkpoint etc. all contribute to Badput.
+program startup, data loading, portions of checkpointing, recovering from
+disruptions, wasted progress since the last checkpoint etc. all contribute to Badput.
+
+The ML Goodput Measurement library exposes Badput Breakdown. Further details of
+each bucket can be found [here](https://github.com/AI-Hypercomputer/ml-goodput-measurement?tab=readme-ov-file#badput-breakdown-details).
+
+### What is Rolling Window Goodput & Badput
+The ML Goodput Measurement library allows users to monitor goodput and badput
+breakdown metrics within specific, moving time windows. You can specify a list
+of rolling window interval sizes in seconds, and the library will asynchronously
+query and upload metrics calculated only within the context of those windows.
+This is useful for understanding workload performance over recent, specific
+durations (e.g., the last 24 hours).

-The ML Goodput Measurement library exposes Badput Breakdown. Further details of each bucket can be found [here](https://github.com/AI-Hypercomputer/ml-goodput-measurement?tab=readme-ov-file#badput-breakdown-details)
+If the workload's actual runtime timeline is shorter than a requested window size,
+the entire runtime timeline of the workload is used for the metrics computation.

-### What is Step Time Deviation
+> **Note**: Both the standard (cumulative) and rolling window query APIs can be enabled simultaneously to get a complete picture of your workload's performance.
+
+### What are Ideal Step Time and Step Time Deviation

 Step Time Deviation is the metric that measures deviation of step time (in
 seconds) from ideal step time. It is the difference between the actual time
@@ -33,8 +51,8 @@ The formula for step deviation is:

 Ideal step time is equal to the user-configured `ideal_step_time` if it is
 provided. If the user has not specified an ideal step time, then the ideal step
-time is calculated as the average of the "normal" step times recorded for the
-workload, where a "normal" step is defined as having a duration less than or
+time is calculated as a weighted average of the "normal" step times recorded for
+the workload, where a "normal" step is defined as having a duration less than or
 equal to `median + median absolute deviation * 3` of the sample space of step
 times. This computation requires at least 10 recorded steps.
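The ideal-step-time rule described in the hunk above can be made concrete with a small sketch. This is illustrative only and not the `ml-goodput-measurement` API: the function names are hypothetical, and where the library uses a weighted average over the "normal" steps, a plain mean is used here for brevity.

```python
# Illustrative sketch only -- not the ml-goodput-measurement API. Function names are hypothetical.
import statistics
from typing import Optional, Sequence


def ideal_step_time(step_times_s: Sequence[float], configured: Optional[float] = None) -> float:
    """Returns the configured ideal step time, or estimates it from recorded step durations."""
    if configured is not None:
        return configured
    if len(step_times_s) < 10:
        raise ValueError("Need at least 10 recorded steps to estimate the ideal step time.")
    median = statistics.median(step_times_s)
    mad = statistics.median(abs(t - median) for t in step_times_s)
    # "Normal" steps have a duration <= median + 3 * MAD; the library weights these,
    # while a plain mean is used here for brevity.
    normal_steps = [t for t in step_times_s if t <= median + 3 * mad]
    return sum(normal_steps) / len(normal_steps)


def step_deviations(step_times_s: Sequence[float], configured: Optional[float] = None) -> list:
    """Per-step deviation (seconds) of the recorded step times from the ideal step time."""
    ideal = ideal_step_time(step_times_s, configured)
    return [t - ideal for t in step_times_s]
```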
@@ -77,7 +95,7 @@ project, then do the following:

 Please use a unique workload name, unless you intend to monitor cumulative
 Goodput/Badput metrics of a previous workload along with your current workload.

-### How to Monitor Goodput and Badput
+### How to Monitor Cumulative Goodput Metrics

 To enable Goodput recording and monitoring on AXLearn, follow the example below.

@@ -94,24 +112,22 @@ To enable Goodput recording and monitoring on AXLearn, follow the example below.
        --recorder_spec=upload_interval=30 \
 ```

-### How to Monitor Step Time Deviation
+### How to Monitor Rolling Window Goodput Metrics

-AXLearn enables step time deviation monitoring by default. You can configure
-the upload frequency by setting
-`--recorder_spec=step_deviation_interval_seconds=30`. To disable step deviation
-set `--recorder_spec=step_deviation_interval_seconds=-1`.
+To enable rolling window metrics, set `enable_rolling_window_goodput_monitoring` to `True`
+and provide a list of interval sizes for `rolling_window_size` in seconds:

 ```bash
-    axlearn gcp launch run --instance_type=tpu-v5litepod-16 \
+axlearn gcp launch run --instance_type=tpu-v5litepod-16 \
        --bundler_type=artifactregistry --bundler_spec=image=tpu \
        --bundler_spec=dockerfile=Dockerfile \
-       --name= \
-       -- python3 -m ...training-config... \
+       -- python3 -m my_training_job \
        --recorder_type=axlearn.cloud.gcp.measurement:goodput \
        --recorder_spec=name= \
        --recorder_spec=upload_dir=my-output-directory/summaries \
        --recorder_spec=upload_interval=30 \
-       --recorder_spec=step_deviation_interval_seconds=30 \
+       --recorder_spec=enable_rolling_window_goodput_monitoring=True \
+       --recorder_spec=rolling_window_size=86400,259200,432000
 ```

 ### Visualize on Tensorboard
@@ -121,12 +137,16 @@ set `--recorder_spec=step_deviation_interval_seconds=-1`.

 ### Enabling Google Cloud Monitoring

-AXLearn has an additional option of pushing goodput, badput and step time
-deviation metrics to Google Cloud Monitoring. By default if goodput monitoring
-is enabled, the data gets published to Google Cloud Monitoring. Set the variables
-`enable_gcp_goodput_metrics` and `enable_gcp_step_deviation_metrics` to `False` in
-`goodput_monitoring.GCPOptions` in `cloud/gcp/measurement.py` to disable goodput and step_deviation
-uploads to GCM respectively.
+By default, when Goodput monitoring is enabled via the recorder, AXLearn automatically
+pushes metrics to Google Cloud Monitoring.
+
+- **Cumulative Metrics** are enabled by default when you specify the `recorder_type`.
+  To disable this, set `enable_gcp_goodput_metrics` to `False` in
+  `goodput_monitoring.GCPOptions` within the `cloud/gcp/measurement.py` file.
+- **Rolling Window Metrics** can be explicitly enabled by setting
+  `enable_rolling_window_goodput_monitoring` to `True` and providing window sizes
+  via `rolling_window_size`.
+
+You can enable cumulative monitoring, rolling window monitoring, or both simultaneously.

 ```bash
    axlearn gcp launch run --instance_type=tpu-v5litepod-16 \
@@ -138,7 +158,8 @@ uploads to GCM respectively.
        --recorder_spec=name= \
        --recorder_spec=upload_dir=my-output-directory/summaries \
        --recorder_spec=upload_interval=30 \
-       --recorder_spec=step_deviation_interval_seconds=30 \
+       --recorder_spec=enable_rolling_window_goodput_monitoring=True \
+       --recorder_spec=rolling_window_size=86400,604800
 ```

 #### Visualization in Google Cloud Monitoring
@@ -159,3 +180,38 @@ To visualize the collected metrics within Google Cloud Monitoring:
 c. [**Performance:**](https://cloud.google.com/monitoring/api/metrics_gcp#:~:text=workload/performance)
    Represents the workload's performance metric, specifically step deviation in this context,
    measured by `compute.googleapis.com/workload/performance`.
+
+#### Google Cloud Monitoring Dashboard: Goodput Monitor
+
+Following are instructions for deploying a custom dashboard `goodput_dashboard.json`
+to your Google Cloud project's Monitoring console. This dashboard
+offers a comprehensive view of "Goodput" metrics, helping you monitor
+your workloads and set up custom alerts for "events" such as performance degradation.
+
+
+#### Deployment Steps
+
+Follow these steps to create a new custom dashboard using the provided JSON
+configuration:
+
+1. **Navigate to the Monitoring Console**: In your Google Cloud project,
+   go to the **Monitoring** section. From the left-hand navigation menu,
+   select **Dashboards**.
+
+2. **Create Custom Dashboard**: Click the **Create Custom Dashboard** button.
+
+3. **Use JSON Editor**: In the new dashboard interface, select the
+   **JSON editor** option.
+
+4. **Copy and Save Configuration**: Open the [goodput_dashboard.json](https://github.com/AI-Hypercomputer/ml-goodput-measurement/blob/main/ml_goodput_measurement/dashboards/goodput_dashboard.json) file.
+   Copy its entire content and paste it into the JSON editor. Once pasted,
+   click **Save**.
+
+
+Your "Goodput Monitor" dashboard should now be visible and operational within
+your custom dashboards list.
+
+> **_NOTE:_** This dashboard is intended to be a starting point for your
+> monitoring needs. We recommend customizing it to meet your specific requirements.
+> Please refer to the [Monitoring Dashboard documentation](https://cloud.google.com/monitoring/dashboards)
+> for further guidance and customization options.
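The trainer hunks earlier in this diff replace paired `START_*`/`END_*` recorder calls with `with self._record_event(...)` blocks. The sketch below is hypothetical and not the AXLearn measurement API; it only illustrates the shape of such a context manager and why it is preferable: the `try/finally` guarantees an end marker even when the wrapped block raises, which the removed call pairs could not guarantee on error paths.

```python
# Hypothetical sketch -- not the AXLearn measurement API; names and logging are illustrative only.
import contextlib
import enum
import logging


class EventType(enum.Enum):
    TRAINING_PREPARATION = "training_preparation"
    DATA_LOADING = "data_loading"
    STEP = "step"
    CUSTOM_BADPUT_EVENT = "custom_badput_event"


@contextlib.contextmanager
def record_event(event_type: EventType, *args, **kwargs):
    """Records a start marker, runs the wrapped block, and always records an end marker."""
    logging.info("start %s args=%s kwargs=%s", event_type.name, args, kwargs)
    try:
        yield
    finally:
        logging.info("end %s", event_type.name)


if __name__ == "__main__":
    # Usage mirroring the trainer: the training step runs inside the recorded span.
    with record_event(EventType.STEP, 123):
        pass  # run one training step here
```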
diff --git a/patches/shard_map.py.patch b/patches/shard_map.py.patch
new file mode 100644
index 000000000..e6e10104f
--- /dev/null
+++ b/patches/shard_map.py.patch
@@ -0,0 +1,27 @@
+--- shard_map_orig.py	2025-06-18 01:27:00.782665547 +0000
++++ shard_map.py	2025-06-18 01:26:06.798346281 +0000
+@@ -1793,10 +1793,10 @@
+ ) -> tuple[core.JaxprEqn, core.JaxprEqn, Sequence[bool], Sequence[bool],
+            list[core.Var]]:
+   jaxpr, mesh = eqn.params['jaxpr'], eqn.params['mesh']
+-  auto = eqn.params['auto']
+-  with _extend_axis_env(mesh, auto):
++  manual_axes = frozenset()
++  with (_extend_axis_env(mesh, manual_axes), use_abstract_mesh(_as_manual_mesh(mesh, manual_axes))):
+     jaxpr_known, jaxpr_staged, unks_out, inst_out, num_res = \
+-        pe.partial_eval_jaxpr_custom(jaxpr, unks_in, inst_in, False, False, saveable)
++        pe.partial_eval_jaxpr_custom(jaxpr, unks_in, inst_in, False, False, saveable)
+   num_out_primals = len(jaxpr_known.outvars) - num_res
+   in_fwd = pe._jaxpr_forwarding(jaxpr_known)[num_out_primals:]
+   out_vars, res_vars = split_list(jaxpr_known.outvars, [num_out_primals])
+@@ -1804,8 +1804,8 @@
+   out_fwd = [idx_map.get(id(v)) for v in res_vars]
+   which = [f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)]
+   mesh = eqn.params['mesh']
+-  with (_extend_axis_env(mesh, auto),
+-        use_abstract_mesh(_as_manual_mesh(mesh, auto))):
++  with (_extend_axis_env(mesh, manual_axes),
++        use_abstract_mesh(_as_manual_mesh(mesh, frozenset()))):
+     jaxpr_known = pe.prune_jaxpr_outputs(jaxpr_known, [True] * num_out_primals + which)
+     jaxpr_known, jaxpr_staged = _add_reshapes(which, jaxpr_known, jaxpr_staged)
+     jaxpr_known = core.remove_named_axis_effects(jaxpr_known, mesh.axis_names)
diff --git a/pyproject.toml b/pyproject.toml
index 57e415e7e..622bb7aa2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -98,7 +98,7 @@ gcp = [
     "google-cloud-compute==1.19.2",  # Needed for region discovery for CloudBuild API access.
     "google-cloud-core==2.3.3",
     "google-cloud-build==3.24.1",
-    "ml-goodput-measurement==0.0.10",
+    "ml-goodput-measurement==0.0.14",
     "pika==1.3.2",  # used by event queue
     "pyOpenSSL>=22.1.0",  # compat with cryptography version.
     "tpu-info==0.2.0",  # For TPU monitoring from libtpu. https://github.com/AI-Hypercomputer/cloud-accelerator-diagnostics/tree/main/tpu_info
@@ -109,7 +109,12 @@ gcp = [
 tpu = [
     "axlearn[gcp]",
     "jax[tpu]==0.5.3",  # must be >=0.4.19 for compat with v5p.
-    "pathwaysutils==0.1.1",  # For JAX+Pathways single-controller accelerator coordinator.
+]
+# For Pathways-TPU single-controller training
+pathways-tpu = [
+    "axlearn[gcp]",
+    "jax==0.5.3",  # must be >=0.4.19 for compat with v5p.
+    "pathwaysutils @ git+https://github.com/AI-Hypercomputer/pathways-utils",
 ]
 # Vertex AI tensorboard. TODO(markblee): Merge with `gcp`.
 vertexai_tensorboard = [