[FT] Support local_sgd / diloco in titan #1122
Conversation
It looks like we can share the training loop? We basically just add another context manager. Additionally, `training_method` is added to the main `JobConfig`, so we shouldn't create an additional `train.py`. If we don't want to expose this feature as the main feature yet, then we have to move the `training_method` configuration to `experiments/fault_tolerance/train.py` as well. You can check `torchtitan/tests/unit_tests/test_job_config.py` for an example.
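For illustration, a generic sketch of what keeping the knob experiment-local could look like. This is not torchtitan's actual extension mechanism (the linked test is the authoritative example); the class and field names below are hypothetical:

```python
from dataclasses import dataclass, field


@dataclass
class FaultTolerance:
    # Fields already present in the main JobConfig section (abridged).
    enable: bool = False
    min_replica_size: int = 1


@dataclass
class ExperimentFaultTolerance(FaultTolerance):
    # Experiment-only knob; None would mean fully synchronous training.
    training_method: str | None = None


@dataclass
class ExperimentJobConfig:
    fault_tolerance: ExperimentFaultTolerance = field(
        default_factory=ExperimentFaultTolerance
    )


print(ExperimentJobConfig().fault_tolerance.training_method)  # None by default
```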
```
    model=self.model_parts[0],
    optimizer=self.optimizers,
    sync_every=2,
) as semi_sync_training:
```
It seems that we are not using this variable. We don't need `as semi_sync_training`?
Yeah, we don't need it. I can remove it.
torchtitan/config_manager.py (Outdated)
```
@@ -502,6 +502,13 @@ class FaultTolerance:
    min_replica_size: int = 1
    """The minimum number of FT replica for each step."""

    training_method: str | None = "diloco"
```
`semi_sync_method` or `synchronize_method` would be less confusing; `training_method` is pretty general. And can we specify that if the value is not set, we will use synchronized training?
`semi_sync_method` sounds good. I will also update that comment.
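A rough sketch of what the renamed field and the clarified docstring could look like (the default and exact wording here are assumptions, not the final code):

```python
from dataclasses import dataclass


@dataclass
class FaultTolerance:
    # Existing fields elided; keeping one for context.
    min_replica_size: int = 1
    """The minimum number of FT replica for each step."""

    semi_sync_method: str | None = None
    """
    Semi-synchronous training method to use, e.g. "local_sgd" or "diloco".
    If unset (None), fully synchronous training is used.
    """
```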
Correct @fegin, we can share the training loop. My main worry was making the regular training loop harder to read by adding this, but yeah, you are right that the configs are already in `JobConfig`. The alternative is to ask the user to use …
I think if we can do it with the same train loop it'd be nice -- maybe we can use ExitStack for cleaner optional registration of the context managers?
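To make the ExitStack idea concrete, a minimal self-contained sketch; `semi_sync_training` below is only a stand-in for the context manager this PR adds, and the toy `train` function is not torchtitan's loop:

```python
import contextlib


@contextlib.contextmanager
def semi_sync_training(sync_every: int):
    # Stand-in for the semi-sync context manager added in this PR.
    print(f"entering semi-sync training (sync_every={sync_every})")
    try:
        yield
    finally:
        print("leaving semi-sync training")


def train(semi_sync_method: str | None = None, sync_every: int = 2) -> None:
    # ExitStack registers the context manager only when it is configured,
    # so the shared loop needs no extra nested `with` block.
    with contextlib.ExitStack() as stack:
        if semi_sync_method is not None:
            stack.enter_context(semi_sync_training(sync_every))
        for step in range(3):
            print(f"training step {step}")


train()                            # fully synchronous path
train(semi_sync_method="diloco")   # semi-sync path wraps the whole loop
```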
I'm okay either way. Since TorchFT is already in the core TorchTitan, I think it is okay to put the semi-sync logic into the main training loop. If you want to put it in the experiment folder, …
LGTM, one small requested change if it makes sense.
```
    model=self.model_parts[0],
    optimizer=self.optimizers,
    sync_every=job_config.fault_tolerance.sync_steps,
):
```
Uh, I just realized that this context can be initialized in `trainer.__init__()`, putting it into `dist_utils.get_train_context`. If my understanding is correct, that will only require changing `dist_utils.get_train_context` and its caller. Let me know if this makes sense.
Here it's entering the context manager for the entire training run, across iterations. `dist_utils.get_train_context` is a per-iteration context manager. Which way are we supposed to use the FT contexts?
Ah interesting. Yeah, the FT context spans iterations, since it adds hooks to the optimizers and performs synchronization every N iterations (based on the optimizer `.step()` calls).
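To make those mechanics concrete, a toy sketch (not torchft's implementation) of a context manager that counts optimizer `.step()` calls via PyTorch's `register_step_post_hook` and triggers a synchronization every N steps:

```python
import contextlib

import torch


@contextlib.contextmanager
def sync_every_n_steps(optimizer: torch.optim.Optimizer, sync_every: int):
    # Toy stand-in: fire a "sync" after every `sync_every` optimizer steps,
    # and remove the hook when the context exits.
    step_count = 0

    def _post_step(opt, args, kwargs):
        nonlocal step_count
        step_count += 1
        if step_count % sync_every == 0:
            print(f"[step {step_count}] synchronizing replicas")

    handle = optimizer.register_step_post_hook(_post_step)
    try:
        yield
    finally:
        handle.remove()


model = torch.nn.Linear(4, 4)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

with sync_every_n_steps(opt, sync_every=2):
    for _ in range(4):
        model(torch.randn(2, 4)).sum().backward()
        opt.step()
        opt.zero_grad()
```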
I see. Maybe it'd be good to organize them into two context manager util functions: one for the overall training run, the other for each training iteration.
Good idea, I will add it in a follow-up PR!
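For illustration, a toy sketch of that two-level split (the names and structure here are hypothetical, not the follow-up PR's actual API):

```python
import contextlib
from typing import Generator


@contextlib.contextmanager
def whole_run_context(semi_sync: bool) -> Generator[None, None, None]:
    """Entered once around the entire training run (semi-sync hooks would live here)."""
    print("run: enter" + (" (semi-sync hooks installed)" if semi_sync else ""))
    try:
        yield
    finally:
        print("run: exit")


@contextlib.contextmanager
def per_step_context(step: int) -> Generator[None, None, None]:
    """Entered once per training iteration (loss parallel, profiling, etc.)."""
    print(f"step {step}: enter")
    try:
        yield
    finally:
        print(f"step {step}: exit")


with whole_run_context(semi_sync=True):
    for step in range(2):
        with per_step_context(step):
            pass  # forward / backward / optimizer.step() would go here
```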
```
@@ -158,3 +162,44 @@ def ft_clip_grad_norm_util(total_norm: DTensor) -> torch.Tensor:
        return DTensor.from_local(local_tensor, mesh.mesh, placements)

    return total_norm


def maybe_semi_sync_training(
```
can we add typing?
will do
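For illustration, one possible typed signature. Only `model`, `optimizer`, and `sync_every` are visible in the diff above; the other parameters and annotations are guesses:

```python
import contextlib
from typing import Any, Generator

import torch


@contextlib.contextmanager
def maybe_semi_sync_training(
    method: str | None,
    ft_manager: Any,  # torchft manager object; the exact type is a guess
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    sync_every: int,
) -> Generator[None, None, None]:
    """Wrap the training run in LocalSGD/DiLoCo when `method` is set; no-op otherwise."""
    # Stub body: the real implementation lives in the PR's FT utilities.
    yield
```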
Depends on torchft changes:
This PR adds a new semi-sync method context manager which wraps around the train loop to run LocalSGD or DiLoCo. It also adds multiple config properties to set and control the training method.
To run (needs 3 different terminals):

Start the torchft lighthouse (terminal 1):

```
RUST_LOGS=debug RUST_BACKTRACE=1 torchft_lighthouse --min_replicas 2 --quorum_tick_ms 100 --join_timeout_ms 10000
```

Start replica 1 (terminal 2, update the lighthouse URL):

```
TORCHFT_LIGHTHOUSE=<url> TORCHFT_MANAGER_PORT=29520 REPLICA_GROUP_ID=0 CUDA_VISIBLE_DEVICES=0,1,2,3 NGPU=4 ./run_train.sh --parallelism.data_parallel_shard_degree=4 --fault_tolerance.enable --fault_tolerance.group_size=2 --fault_tolerance.replica_id=0 --fault_tolerance.semi_sync_method="diloco"
```

Start replica 2 (terminal 3, update the lighthouse URL):

```
TORCHFT_LIGHTHOUSE=<url> TORCHFT_MANAGER_PORT=29522 REPLICA_GROUP_ID=1 CUDA_VISIBLE_DEVICES=4,5,6,7 NGPU=4 ./run_train.sh --parallelism.data_parallel_shard_degree=4 --fault_tolerance.enable --fault_tolerance.group_size=2 --fault_tolerance.replica_id=1 --fault_tolerance.semi_sync_method="diloco"
```