
Commit c6c28dc

[FT] Support local_sgd / diloco in titan (#1122)
Depends on torchft changes:
- pytorch/torchft#168
- pytorch/torchft#170

This PR adds a new semi-sync method context manager which wraps around the train loop to run LocalSGD or DiLoCo. It also adds multiple config properties to set and control the training method.

### To run (need 3 different terminals):

Start the torchft lighthouse (terminal 1):

```
RUST_LOGS=debug RUST_BACKTRACE=1 torchft_lighthouse --min_replicas 2 --quorum_tick_ms 100 --join_timeout_ms 10000
```

Start replica 1 (terminal 2, update the lighthouse URL):

```
TORCHFT_LIGHTHOUSE=<url> TORCHFT_MANAGER_PORT=29520 REPLICA_GROUP_ID=0 CUDA_VISIBLE_DEVICES=0,1,2,3 NGPU=4 ./run_train.sh --parallelism.data_parallel_shard_degree=4 --fault_tolerance.enable --fault_tolerance.group_size=2 --fault_tolerance.replica_id=0 --fault_tolerance.semi_sync_method="diloco"
```

Start replica 2 (terminal 3, update the lighthouse URL):

```
TORCHFT_LIGHTHOUSE=<url> TORCHFT_MANAGER_PORT=29522 REPLICA_GROUP_ID=1 CUDA_VISIBLE_DEVICES=4,5,6,7 NGPU=4 ./run_train.sh --parallelism.data_parallel_shard_degree=4 --fault_tolerance.enable --fault_tolerance.group_size=2 --fault_tolerance.replica_id=1 --fault_tolerance.semi_sync_method="diloco"
```
1 parent f27a184 commit c6c28dc

3 files changed: +76 −3 lines changed


torchtitan/components/ft.py (+56 −2)

```diff
@@ -6,8 +6,9 @@

 import copy
 import importlib
+from contextlib import nullcontext
 from dataclasses import dataclass
-from typing import Optional
+from typing import ContextManager, Optional, TYPE_CHECKING, Union

 import torch
 import torch.distributed as dist
@@ -22,6 +23,9 @@
 if importlib.util.find_spec("torchft") is not None:
     import torchft as ft

+    if TYPE_CHECKING:
+        from torchft import local_sgd
+
     has_torchft = True
 else:
     has_torchft = False
@@ -85,13 +89,16 @@ def init_ft_manager(job: JobConfig) -> FTManager:

     pg = ft.ProcessGroupNCCL()

+    # If a semi-sync training method is specified, the quorum should be synchronous
+    use_async_quorum = job.fault_tolerance.semi_sync_method is None
+
     return FTManager(
         ft.Manager(
             pg=pg,
             min_replica_size=job.fault_tolerance.min_replica_size,
             load_state_dict=None,
             state_dict=None,
-            use_async_quorum=True,
+            use_async_quorum=use_async_quorum,
             replica_id=f"torchtitan_ft_{job.fault_tolerance.replica_id}",
         ),
         group_size=job.fault_tolerance.group_size,
@@ -158,3 +165,50 @@ def ft_clip_grad_norm_util(total_norm: DTensor) -> torch.Tensor:
             return DTensor.from_local(local_tensor, mesh.mesh, placements)

     return total_norm
+
+
+def maybe_semi_sync_training(
+    config: JobConfig,
+    ft_manager: FTManager,
+    model: torch.nn.Module,
+    optimizer: torch.optim.Optimizer,
+    sync_every: int,
+) -> ContextManager[Union["local_sgd.DiLoCo", "local_sgd.LocalSGD", None]]:
+    """
+    If TorchFT is enabled and semi_sync_method is set, wrap training in the requested semi-sync context manager.
+    """
+    semi_sync_method = config.fault_tolerance.semi_sync_method
+    torchft_enabled = config.fault_tolerance.enable
+    if torchft_enabled and semi_sync_method is not None:
+        from torchft import local_sgd
+
+        assert (
+            ft_manager._manager is not None
+        ), "FTManager must be enabled to use semi-sync training."
+        if semi_sync_method.lower() == "diloco":
+            # Create the outer optimizer based on the inner optimizer parameters.
+            params = [group["params"] for group in optimizer.param_groups]
+            params = [param for sublist in params for param in sublist]
+            outer_optimizer = torch.optim.SGD(
+                params, lr=0.7, momentum=0.9, nesterov=True
+            )
+
+            return local_sgd.DiLoCo(
+                manager=ft_manager._manager,
+                model=model,
+                inner_optimizer=optimizer,
+                outer_optimizer=outer_optimizer,
+                sync_every=sync_every,
+            )
+        elif semi_sync_method.lower() == "local_sgd":
+            return local_sgd.LocalSGD(
+                manager=ft_manager._manager,
+                model=model,
+                optimizer=optimizer,
+                sync_every=sync_every,
+            )
+        else:
+            raise ValueError(
+                f"Unknown training method: {semi_sync_method}, only 'diloco' and 'local_sgd' are supported."
+            )
+    return nullcontext()
```
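For context on what `local_sgd.DiLoCo` does with the outer SGD constructed above: each replica takes `sync_every` inner optimizer steps locally, then the drift of its weights from the last globally agreed snapshot is averaged across replicas and applied as a pseudo-gradient by the outer Nesterov SGD. The snippet below is a single-process sketch of that update rule only; it is not torchft's implementation and performs no cross-replica communication or fault handling (that is what the torchft `Manager` provides).

```python
# Single-process sketch of the DiLoCo update rule applied by local_sgd.DiLoCo.
# Illustrative only: real DiLoCo averages the pseudo-gradient across replicas
# via the torchft Manager; there is no communication here.
import torch

model = torch.nn.Linear(8, 1)
inner_opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
outer_opt = torch.optim.SGD(model.parameters(), lr=0.7, momentum=0.9, nesterov=True)

sync_every = 5
# Snapshot of the "global" parameters at the last synchronization point.
global_params = [p.detach().clone() for p in model.parameters()]

for step in range(1, 21):
    x = torch.randn(16, 8)
    loss = model(x).pow(2).mean()
    inner_opt.zero_grad()
    loss.backward()
    inner_opt.step()  # local (inner) step, no cross-replica traffic

    if step % sync_every == 0:
        # Pseudo-gradient: drift of the local weights from the last global snapshot.
        for p, g in zip(model.parameters(), global_params):
            p.grad = g - p.detach()
            p.data.copy_(g)  # the outer step starts from the global snapshot
        outer_opt.step()  # global (outer) Nesterov SGD step
        global_params = [p.detach().clone() for p in model.parameters()]
```

The `lr=0.7, momentum=0.9, nesterov=True` settings in the sketch mirror the values hard-coded for the outer optimizer in `maybe_semi_sync_training`.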

torchtitan/config_manager.py (+13 −0)

```diff
@@ -499,6 +499,19 @@ class FaultTolerance:
     min_replica_size: int = 1
     """The minimum number of FT replica for each step."""

+    semi_sync_method: str | None = None
+    """
+    The algorithm to use for semi-sync training. Currently, only "local_sgd" and "diloco" from
+    torchft are supported
+    (https://github.com/pytorch/torchft/blob/360c5c534bdeac959507e9d238ba9f3902d3fda9/torchft/local_sgd.py#L41)
+    """
+
+    sync_steps: int = 5
+    """
+    Number of steps to wait before performing synchronization. This is only used when "semi_sync_method"
+    is set.
+    """
+

 @dataclass
 class Experimental:
```
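As a quick worked example of the `sync_steps` default: with `sync_steps = 5`, synchronization fires after every fifth optimizer step, e.g. at steps 5, 10, 15, and 20 of a 20-step run. The counting itself lives inside torchft's `LocalSGD`/`DiLoCo` (via their `sync_every` argument); the snippet below only illustrates the cadence.

```python
# Cadence implied by the sync_steps=5 default over a 20-step run.
sync_steps = 5
training_steps = 20
sync_points = [s for s in range(1, training_steps + 1) if s % sync_steps == 0]
print(sync_points)  # [5, 10, 15, 20]
```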

torchtitan/train.py (+7 −1)

```diff
@@ -401,7 +401,13 @@ def train(self):
             job_config, global_step=self.step
         ) as torch_profiler, maybe_enable_memory_snapshot(
             job_config, global_step=self.step
-        ) as memory_profiler:
+        ) as memory_profiler, ft.maybe_semi_sync_training(
+            job_config,
+            ft_manager=self.ft_manager,
+            model=self.model_parts[0],
+            optimizer=self.optimizers,
+            sync_every=job_config.fault_tolerance.sync_steps,
+        ):
             data_iterator = iter(self.dataloader)
             while self.step < job_config.training.steps:
                 self.step += 1
```
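Because `maybe_semi_sync_training` returns `nullcontext()` whenever fault tolerance is disabled or `semi_sync_method` is unset, the extended `with` statement above is a no-op for ordinary runs. Here is a minimal, self-contained sketch of that "maybe" context-manager pattern; `Recorder` and `maybe_record` are hypothetical stand-ins, not torchtitan or torchft APIs.

```python
# Minimal sketch of the "maybe" context-manager pattern: hand back the real
# context manager when the feature is on, otherwise a no-op nullcontext.
from contextlib import nullcontext
from typing import ContextManager, Optional


class Recorder:
    def __enter__(self) -> "Recorder":
        print("semi-sync wrapper entered")
        return self

    def __exit__(self, *exc) -> bool:
        print("semi-sync wrapper exited")
        return False


def maybe_record(enabled: bool) -> ContextManager[Optional[Recorder]]:
    return Recorder() if enabled else nullcontext()


# Call sites stay identical whether or not the feature is enabled.
for enabled in (False, True):
    with maybe_record(enabled):
        pass  # train loop body goes here
```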
