
[FT] Support local_sgd / diloco in titan #1122

Merged: 1 commit into pytorch:main on Apr 29, 2025

Conversation

@H-Huang (Member) commented Apr 18, 2025

Depends on torchft changes:

This PR adds a new semi-sync method context manager that wraps the train loop to run LocalSGD or DiLoCo. It also adds several config properties to select and control the training method. A rough usage sketch follows the run commands below.

To run (need 3 different terminals):

Start torchft lighthouse (terminal 1):
RUST_LOGS=debug RUST_BACKTRACE=1 torchft_lighthouse --min_replicas 2 --quorum_tick_ms 100 --join_timeout_ms 10000

Start replica 1 (terminal 2, update lighthouse URL):
TORCHFT_LIGHTHOUSE=<url> TORCHFT_MANAGER_PORT=29520 REPLICA_GROUP_ID=0 CUDA_VISIBLE_DEVICES=0,1,2,3 NGPU=4 ./run_train.sh --parallelism.data_parallel_shard_degree=4 --fault_tolerance.enable --fault_tolerance.group_size=2 --fault_tolerance.replica_id=0 --fault_tolerance.semi_sync_method="diloco"

Start replica 2 (terminal 3, update lighthouse URL):
TORCHFT_LIGHTHOUSE=<url> TORCHFT_MANAGER_PORT=29522 REPLICA_GROUP_ID=1 CUDA_VISIBLE_DEVICES=4,5,6,7 NGPU=4 ./run_train.sh --parallelism.data_parallel_shard_degree=4 --fault_tolerance.enable --fault_tolerance.group_size=2 --fault_tolerance.replica_id=1 --fault_tolerance.semi_sync_method="diloco"
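
For reference, a minimal sketch of how the new context manager is meant to wrap the outer training loop (names such as `self.ft_manager` and `self.train_step` are assumptions; the actual wiring is in the diff):

```python
# Sketch only: the semi-sync context is entered once, around the whole loop.
with maybe_semi_sync_training(
    job_config.fault_tolerance.semi_sync_method,  # "local_sgd", "diloco", or None
    ft_manager=self.ft_manager,                   # assumed attribute name
    model=self.model_parts[0],
    optimizer=self.optimizers,
    sync_every=job_config.fault_tolerance.sync_steps,
):
    while self.step < job_config.training.steps:
        self.step += 1
        self.train_step(...)  # one regular training iteration (assumed method)
```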

@facebook-github-bot added the CLA Signed label (managed by the Meta Open Source bot) on Apr 18, 2025
@H-Huang changed the title from [FT] Support local_sgd / diloco in titan to [WIP] [FT] Support local_sgd / diloco in titan on Apr 18, 2025
@H-Huang force-pushed the diloco branch 2 times, most recently from 7d1a96d to 3c0a9f8 on April 22, 2025 at 17:33
@H-Huang changed the title from [WIP] [FT] Support local_sgd / diloco in titan to [FT] Support local_sgd / diloco in titan on Apr 22, 2025
@H-Huang marked this pull request as ready for review on April 22, 2025 at 17:47
@H-Huang requested review from fegin and d4l3k on April 22, 2025 at 17:56
@fegin (Contributor) left a comment

It looks like we can share the training loop? We basically just add another context manager. Additionally, training_method is added to the main JobConfig, so we shouldn't create an additional train.py.

If we don't want to expose this feature as a main feature yet, then we have to move the training_method configuration to experiments/fault_tolerance/train.py as well. You can check torchtitan/tests/unit_tests/test_job_config.py for an example.

model=self.model_parts[0],
optimizer=self.optimizers,
sync_every=2,
) as semi_sync_training:
Contributor

It seems we aren't using this variable, so we don't need the as semi_sync_training part?

Member Author

Yeah, we don't need it. I can remove it.

@@ -502,6 +502,13 @@ class FaultTolerance:
min_replica_size: int = 1
"""The minimum number of FT replica for each step."""

training_method: str | None = "diloco"
Contributor

semi_sync_method or synchronize_method would be less confusing; training_method is pretty general. Also, can we specify that if the value is not set, we will use synchronized training?

Member Author

semi_sync_method sounds good. I will also update that comment.
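
For concreteness, the agreed change might look roughly like this (a sketch, not the final diff; the default value and docstring wording are assumptions):

```python
from dataclasses import dataclass


@dataclass
class FaultTolerance:
    # ... existing fields such as min_replica_size ...

    semi_sync_method: str | None = None
    """
    The semi-synchronous training method to use, e.g. "local_sgd" or "diloco".
    If not set (None), regular fully synchronous training is used.
    """
```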

@H-Huang (Member Author) commented Apr 22, 2025

Correct @fegin, we can share the training loop. My main worry was making the regular training loop harder to read by adding this, but yeah you are right that the configs are already in JobConfig.

The alternative is to ask the user to do from torchtitan.experiments.fault_tolerance import FtTrainer and replace their trainer with trainer = FtTrainer(config), but I'm not sure we currently support swapping in another trainer. Maybe your proposal is the right way to go.

@d4l3k (Member) commented Apr 22, 2025

I think if we can do it with the same train loop it'd be nice -- maybe we can use ExitStack for cleaner optional registration of the context managers?
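
For illustration, a minimal ExitStack sketch (assuming the helper and config names from the diff hunks in this review; `self.ft_manager` is an assumed attribute):

```python
import contextlib

# Register the semi-sync context only when it is configured, so the body of the
# training loop stays identical either way.
with contextlib.ExitStack() as stack:
    if job_config.fault_tolerance.semi_sync_method is not None:
        stack.enter_context(
            maybe_semi_sync_training(
                job_config.fault_tolerance.semi_sync_method,
                ft_manager=self.ft_manager,
                model=self.model_parts[0],
                optimizer=self.optimizers,
                sync_every=job_config.fault_tolerance.sync_steps,
            )
        )
    # ... run the unchanged training loop here ...
```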

@fegin (Contributor) commented Apr 22, 2025

I'm okay either way. Since TorchFT is already in core TorchTitan, I think it is okay to put the semi-sync logic in the main training loop.

If you want to put it in the experiment folder, custom_args_module is all you need.

@fegin (Contributor) left a comment

LGTM, with one small requested change if it makes sense.

model=self.model_parts[0],
optimizer=self.optimizers,
sync_every=job_config.fault_tolerance.sync_steps,
):
Contributor

Uh, I just realized that this context could be initialized in trainer.__init__() and put into dist_utils.get_train_context. If my understanding is correct, that would only require changing dist_utils.get_train_context and its caller. Let me know if this makes sense.

Contributor

Here we're entering the context manager for the entire training run, across iterations.
dist_utils.get_train_context is a per-iteration context manager.
Which way are the ft contexts supposed to be used?

Member Author

Ah, interesting. Yeah, the FT context spans iterations, since it adds hooks to the optimizers and performs synchronization every N iterations (based on the optimizer .step() calls).

Contributor

I see. Maybe it'd be good to organize them into two context manager util functions: one for the overall training run, the other for each training iteration.
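
Something roughly like this (a sketch; the per-iteration function name is hypothetical):

```python
import contextlib


@contextlib.contextmanager
def maybe_semi_sync_training(semi_sync_method, ft_manager, model, optimizer, sync_every):
    # Whole-run context: entered once around the entire training loop, because it
    # installs optimizer hooks that synchronize every `sync_every` steps.
    # (Enter torchft's LocalSGD/DiLoCo context here when semi_sync_method is set.)
    yield


@contextlib.contextmanager
def get_train_step_context(job_config):
    # Per-iteration context: entered around each train_step (loss parallel,
    # compiled autograd, etc.), like today's dist_utils.get_train_context.
    yield
```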

Member Author

Good idea, I will add that in a follow-up PR!

@@ -158,3 +162,44 @@ def ft_clip_grad_norm_util(total_norm: DTensor) -> torch.Tensor:
return DTensor.from_local(local_tensor, mesh.mesh, placements)

return total_norm


def maybe_semi_sync_training(
Contributor

can we add typing?

Member Author

will do
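
For reference, one possible typed signature, inferred from the call sites in this PR (the manager and return types are assumptions):

```python
from contextlib import AbstractContextManager

import torch


def maybe_semi_sync_training(
    semi_sync_method: str | None,
    ft_manager,  # torchft manager wrapper; exact type assumed
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,  # or the trainer's optimizers container
    sync_every: int,
) -> AbstractContextManager:
    ...
```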

@H-Huang merged commit c6c28dc into pytorch:main on Apr 29, 2025
5 of 6 checks passed
Labels: CLA Signed (managed by the Meta Open Source bot)
5 participants