@xibinliu xibinliu commented Oct 24, 2025

Description

Integrate SDK for managed profiler

  • include new SDK google-cloud-mldiagnostics
    seed-env: --seed-commit=459cb056418de7a56c9da0a2842406a58b75e4a3
  • add new config params
  • modify profiler.py to add ML run and profiling
  • modify metrics_logger.py to upload metrics

IMPORTANT

Because GCP UI support has not been formally rolled out yet, this feature currently works only in supercomputer-testing / us-central1. Enabling it in any other project or region will fail.

Tests

Commands:

  1. Enable the feature with managed_profiler=True managed_profiler_run_group="<group_name>":
python3 -m MaxText.train src/MaxText/configs/base.yml run_name="xibin-run22" model_name="gpt3-52k" base_output_directory=gs://xibin-images/ dataset_type=synthetic steps=22 profiler=xplane managed_profiler=True managed_profiler_run_group="xibin-demo" log_period=5
  2. Enable the feature with managed_profiler=True, with run_group defaulting to run_name:
python3 -m MaxText.train src/MaxText/configs/base.yml run_name="xibin-run23" model_name="gpt3-52k" base_output_directory=gs://xibin-images/ dataset_type=synthetic steps=22 profiler=xplane managed_profiler=True log_period=5
  3. Enable the feature with managed_profiler=True upload_all_profiler_results=True on all TPU devices, where run_name becomes <run_name>-0 (jax device 0), <run_name>-1 (jax device 1), etc.:
python3 -m MaxText.train src/MaxText/configs/base.yml run_name="xibin-run24" model_name="gpt3-52k" base_output_directory=gs://xibin-images/ dataset_type=synthetic steps=22 profiler=xplane managed_profiler=True upload_all_profiler_results=True log_period=5
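For reviewers, the naming behavior exercised by the three commands above can be sketched roughly as follows. This is an illustration only, not the code in this PR: the helper name `resolve_run_identity` and its `index` parameter are hypothetical, and it assumes the default run group is the unsuffixed run_name.

```python
def resolve_run_identity(run_name, run_group="", upload_all=False, index=0):
  """Hypothetical helper: returns (run_name, run_group) as described in the tests."""
  # Assumption: an empty run group falls back to the base (unsuffixed) run name.
  group = run_group or run_name
  # With upload_all_profiler_results=True, each uploader gets a distinct
  # run name of the form <run_name>-<index>.
  name = f"{run_name}-{index}" if upload_all else run_name
  return name, group


if __name__ == "__main__":
  print(resolve_run_identity("xibin-run22", "xibin-demo"))
  print(resolve_run_identity("xibin-run23"))
  print(resolve_run_identity("xibin-run24", upload_all=True, index=1))
```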

See all uploaded runs in the managed profiler GCP UI

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

Collaborator

@bvandermoon bvandermoon left a comment

LGTM just a few comments. Thanks @xibinliu, great to see this

Comment on lines 44 to 46
# Don't log the following keys.
KEYS_NO_LOGGING = ["hf_access_token"]

Collaborator

@SamuelMarks FYI we will need this to be compatible with Pydantic in #1836. It's just a constant so should be straightforward

Collaborator Author

I'm not sure what action this comment calls for, but I changed it to a tuple to make it immutable.

Collaborator

Thanks @xibinliu. No action item on this one, just calling out that this will need to be updated in the other PR as well.
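For context on the tuple change discussed in this thread, here is a minimal sketch of how such a constant might be used to filter keys before logging. The helper `loggable_items` is hypothetical and is not claimed to match the code in this PR. Note that a one-element tuple needs the trailing comma; without it, `("hf_access_token")` is just a string and `in` would do substring matching.

```python
# Don't log the following keys. The trailing comma makes this a tuple.
KEYS_NO_LOGGING = ("hf_access_token",)


def loggable_items(config):
  """Hypothetical helper: returns a copy of config without the excluded keys."""
  return {k: v for k, v in config.items() if k not in KEYS_NO_LOGGING}


if __name__ == "__main__":
  print(loggable_items({"run_name": "demo", "hf_access_token": "secret"}))
```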

@xibinliu xibinliu force-pushed the xibin/diagon_sdk branch 3 times, most recently from 26d3b89 to 5176ace Compare October 25, 2025 19:29
@xibinliu xibinliu requested a review from shuningjin as a code owner October 25, 2025 19:29
@xibinliu xibinliu force-pushed the xibin/diagon_sdk branch 2 times, most recently from 68150f6 to 35b9724 Compare October 25, 2025 19:35
raise ValueError("Profiling requested but initial profiling step set past training final step")

# Set up the managed profiler on the first device, or all devices, depending on the config.
proc_id = jax.process_index()
Collaborator

Instead of lines 50-52, can't we just say:

if config.managed_profiler:
  self.prof = None
  .....

Collaborator Author

Hi Surbhi, the logic is needed because:

  1. If upload_all_profiler_results is set, we need to do this on all TPU devices.
  2. If upload_all_profiler_results is not set, we do it only on the first device.

The flag self.managed_profiler is set according to these conditions.
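The two conditions above can be sketched as a small predicate. This is an illustrative reading of the explanation, not the exact code in profiler.py: the function name `use_managed_profiler` is hypothetical, and `process_index` stands in for the value returned by `jax.process_index()` in the snippet under review.

```python
def use_managed_profiler(managed_profiler, upload_all_profiler_results, process_index):
  """Hypothetical predicate: should this process set up the managed profiler?"""
  if not managed_profiler:
    # Feature disabled in the config: no process uses the managed profiler.
    return False
  # Either every process uploads, or only the first one (index 0) does.
  return bool(upload_all_profiler_results) or process_index == 0


if __name__ == "__main__":
  for idx in range(3):
    print(idx, use_managed_profiler(True, False, idx))
```

Keeping the decision in one predicate makes it easy to see why a bare `if config.managed_profiler:` check would not be enough: it cannot distinguish the first process from the others when upload_all_profiler_results is off.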

Collaborator

@bvandermoon bvandermoon left a comment

How is the correct GCP project info picked up?

- include new SDK google-cloud-mldiagnostics
- add new config params
- modify profiler.py to add ML run and profiling
- modify metrics_logger.py to upload metrics
@xibinliu xibinliu changed the title Integrate Diagon SDK for managed profiler Integrate SDK for managed profiler Oct 27, 2025
@xibinliu
Collaborator Author

How is the correct GCP project info picked up?

The ML Run (managed profiler UI) is always created under the project / regions where the workload is running

@bvandermoon
Collaborator

How is the correct GCP project info picked up?

The ML Run (managed profiler UI) is always created under the project / regions where the workload is running

Thanks @xibinliu. In that case, the info is coming from XPK, right?
