Integrate SDK for managed profiler #2544
06830de to 48f6fca
LGTM just a few comments. Thanks @xibinliu, great to see this
# Don't log the following keys.
KEYS_NO_LOGGING = ["hf_access_token"]
@SamuelMarks FYI we will need this to be compatible with Pydantic in #1836. It's just a constant so should be straightforward
I'm not sure what action this comment needs, but I changed it to a tuple to make it immutable.
Thanks @xibinliu. No action item on this one, just calling out that this will also need to be updated in the other PR.
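For readers following this thread, here is a minimal sketch of how a tuple constant like this might be used to keep sensitive keys out of logs. The helper `loggable_items` and the sample config are hypothetical and are not code from this PR.

```python
# Keys whose values must never be logged. A tuple keeps the constant immutable.
KEYS_NO_LOGGING = ("hf_access_token",)


def loggable_items(config: dict) -> dict:
    """Return a copy of config without the keys in KEYS_NO_LOGGING (hypothetical helper)."""
    return {key: value for key, value in config.items() if key not in KEYS_NO_LOGGING}


if __name__ == "__main__":
    sample = {"run_name": "demo", "hf_access_token": "secret"}
    print(loggable_items(sample))
```

A membership test works the same way on a tuple as on a list, so callers should not need to change.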
26d3b89 to 5176ace
68150f6 to 35b9724
raise ValueError("Profiling requested but initial profiling step set past training final step")

# Set up the managed profiler on the first device, or all devices, depending on the config.
proc_id = jax.process_index()
Instead of lines 50-52, can't we just say:
if config.managed_profiler:
    self.prof = None
    .....
Hi Surbhi, the logic is needed because:
- if upload_all_profiler_results is set, we need to do this on all TPU devices.
- if upload_all_profiler_results is not set, we just do it on the first device.
The flag self.managed_profiler is set based on the above conditions.
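The two conditions above can be sketched as a small pure function. This is only an illustration under assumptions: the function name is hypothetical, and "first device" is taken to mean the process whose `jax.process_index()` is 0, which the diff snippet suggests but does not confirm.

```python
def managed_profiler_enabled(managed_profiler: bool,
                             upload_all_profiler_results: bool,
                             process_index: int) -> bool:
    """Decide whether this process should set up the managed profiler (hypothetical helper).

    process_index is assumed to be the value of jax.process_index().
    """
    if not managed_profiler:
        # Feature switched off in the config: never set up the profiler.
        return False
    # Either every process uploads, or only the first one (index 0) does.
    return upload_all_profiler_results or process_index == 0


if __name__ == "__main__":
    for index in range(3):
        print(index, managed_profiler_enabled(True, False, index))
```

Keeping the decision in one place like this would make it easy to unit test without a TPU.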
How is the correct GCP project info picked up?
- include new SDK google-cloud-mldiagnostics
- add new config params
- modify profiler.py to add ML run and profiling
- modify metrics_logger.py to upload metrics
35b9724 to 117e4c5
The ML Run (managed profiler UI) is always created under the project and region where the workload is running.
Thanks @xibinliu. In that case, the info is coming from XPK, right?
Description
Integrate SDK for managed profiler
seed-env: --seed-commit=459cb056418de7a56c9da0a2842406a58b75e4a3
IMPORTANT
Since the GCP UI support is not formally rolled out yet, currently this feature only works in supercomputer-testing / us-central1. Enabling this feature in other projects and regions will fail.
Tests
Command:
- `managed_profiler=True managed_profiler_run_group="<group_name>"`
- `managed_profiler=True`, with `run_group` defaulting to `run_name`
- `managed_profiler=True upload_all_profiler_results=True` on all TPU devices, where `run_name` becomes `<run_name>-0` (jax device 0), `<run_name>-1` (jax device 1), etc.

See all uploaded runs in the managed profiler GCP UI.
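The naming behavior exercised by these test cases can be sketched roughly as follows. The function `resolve_run_identity` and its parameters are hypothetical and are not the SDK's API. It also assumes that the default run group stays the unsuffixed `run_name` when results are uploaded from all devices, which the description does not state.

```python
def resolve_run_identity(run_name: str,
                         run_group: str = "",
                         upload_all_profiler_results: bool = False,
                         device_index: int = 0) -> tuple[str, str]:
    """Return (run name, run group) for one device (hypothetical helper)."""
    # The run group falls back to the base run name when none is given.
    group = run_group or run_name
    # When every device uploads, each run name gets a per-device suffix.
    name = f"{run_name}-{device_index}" if upload_all_profiler_results else run_name
    return name, group


if __name__ == "__main__":
    print(resolve_run_identity("my_run"))
    print(resolve_run_identity("my_run", run_group="my_group"))
    print(resolve_run_identity("my_run", upload_all_profiler_results=True, device_index=1))
```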
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-review label.