Commit 16237d5

Run local_sgd/diloco in titan
1 parent e0d2de6 commit 16237d5

File tree

3 files changed: 76 additions, 3 deletions

torchtitan/components/ft.py

+56 -2
@@ -6,8 +6,9 @@
 
 import copy
 import importlib
+from contextlib import nullcontext
 from dataclasses import dataclass
-from typing import Optional
+from typing import ContextManager, Optional, TYPE_CHECKING, Union
 
 import torch
 import torch.distributed as dist
@@ -22,6 +23,9 @@
 if importlib.util.find_spec("torchft") is not None:
     import torchft as ft
 
+    if TYPE_CHECKING:
+        from torchft import local_sgd
+
     has_torchft = True
 else:
     has_torchft = False
@@ -85,13 +89,16 @@ def init_ft_manager(job: JobConfig) -> FTManager:
 
     pg = ft.ProcessGroupNCCL()
 
+    # If the training method is specified, then the quorum should be synchronous
+    use_async_quorum = job.fault_tolerance.semi_sync_method is None
+
     return FTManager(
         ft.Manager(
             pg=pg,
             min_replica_size=job.fault_tolerance.min_replica_size,
             load_state_dict=None,
             state_dict=None,
-            use_async_quorum=True,
+            use_async_quorum=use_async_quorum,
             replica_id=f"torchtitan_ft_{job.fault_tolerance.replica_id}",
         ),
         group_size=job.fault_tolerance.group_size,
@@ -158,3 +165,50 @@ def ft_clip_grad_norm_util(total_norm: DTensor) -> torch.Tensor:
         return DTensor.from_local(local_tensor, mesh.mesh, placements)
 
     return total_norm
+
+
+def maybe_semi_sync_training(
+    config: JobConfig,
+    ft_manager: FTManager,
+    model: torch.nn.Module,
+    optimizer: torch.optim.Optimizer,
+    sync_every: int,
+) -> ContextManager[Union["local_sgd.DiLoCo", "local_sgd.LocalSGD", None]]:
+    """
+    If TorchFT is enabled and the config is set, use semi_sync_method
+    """
+    semi_sync_method = config.fault_tolerance.semi_sync_method
+    torchft_enabled = config.fault_tolerance.enable
+    if torchft_enabled and semi_sync_method is not None:
+        from torchft import local_sgd
+
+        assert (
+            ft_manager._manager is not None
+        ), "FTManager must be enabled to use semi-sync training."
+        if semi_sync_method.lower() == "diloco":
+            # Create the outer optimizer based on the inner optimizer parameters.
+            params = [group["params"] for group in optimizer.param_groups]
+            params = [param for sublist in params for param in sublist]
+            outer_optimizer = torch.optim.SGD(
+                params, lr=0.7, momentum=0.9, nesterov=True
+            )
+
+            return local_sgd.DiLoCo(
+                manager=ft_manager._manager,
+                model=model,
+                inner_optimizer=optimizer,
+                outer_optimizer=outer_optimizer,
+                sync_every=sync_every,
+            )
+        elif semi_sync_method.lower() == "local_sgd":
+            return local_sgd.LocalSGD(
+                manager=ft_manager._manager,
+                model=model,
+                optimizer=optimizer,
+                sync_every=sync_every,
+            )
+        else:
+            raise ValueError(
+                f"Unknown training method: {semi_sync_method}, only 'diloco' and 'local_sgd' are supported."
+            )
+    return nullcontext()
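
The helper is intended to wrap the whole training loop: it returns a torchft LocalSGD or DiLoCo context manager when a semi-sync method is configured, and a no-op nullcontext otherwise. A minimal usage sketch follows; the run_training function, the dataloader handling, and the placeholder loss are illustrative only and not part of this commit.

import torchtitan.components.ft as ft

def run_training(job_config, ft_manager, model, optimizer, dataloader):
    # Same calling pattern as train.py: positional config, keyword args for the rest.
    with ft.maybe_semi_sync_training(
        job_config,
        ft_manager=ft_manager,
        model=model,
        optimizer=optimizer,
        sync_every=job_config.fault_tolerance.sync_steps,
    ):
        # Inside the context, the torchft wrapper counts optimizer steps and
        # synchronizes replicas every `sync_every` steps; with semi-sync
        # disabled this is a nullcontext and the loop runs unchanged.
        for batch in dataloader:
            loss = model(batch).sum()  # placeholder loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()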

torchtitan/config_manager.py

+13
@@ -502,6 +502,19 @@ class FaultTolerance:
     min_replica_size: int = 1
     """The minimum number of FT replica for each step."""
 
+    semi_sync_method: str | None = None
+    """
+    The algorithm to use for semi-sync training. Currently, only "local_sgd" and "diloco" from
+    torchft are supported
+    (https://github.com/pytorch/torchft/blob/360c5c534bdeac959507e9d238ba9f3902d3fda9/torchft/local_sgd.py#L41)
+    """
+
+    sync_steps: int = 5
+    """
+    Number of steps to wait before performing synchronization. This is only used when "semi_sync_method"
+    is set.
+    """
+
 
 @dataclass
 class Experimental:
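
These two options drive the new code path in ft.py. For intuition, here is a rough single-process sketch of the DiLoCo-style schedule they select: sync_steps local steps with the inner optimizer, then an outer SGD step (lr=0.7, momentum=0.9, nesterov=True, matching ft.py) on the pseudo-gradient. It is an illustration only, not torchft's implementation; the cross-replica all-reduce of the pseudo-gradient is omitted, and the diloco_sketch name, the data iterable, and the placeholder loss are assumptions.

import torch

def diloco_sketch(model, inner_optimizer, data, sync_steps=5):
    # Outer optimizer hyperparameters match the ones chosen in ft.py.
    outer_optimizer = torch.optim.SGD(
        model.parameters(), lr=0.7, momentum=0.9, nesterov=True
    )
    global_params = [p.detach().clone() for p in model.parameters()]

    for step, batch in enumerate(data, start=1):
        loss = model(batch).sum()  # placeholder loss
        loss.backward()
        inner_optimizer.step()
        inner_optimizer.zero_grad()

        if step % sync_steps == 0:
            # Pseudo-gradient: global weights minus the locally updated weights.
            # torchft averages this delta across replicas before the outer step;
            # a single-process sketch skips that all-reduce.
            for p, g in zip(model.parameters(), global_params):
                p.grad = g - p.detach()
                p.data.copy_(g)  # outer step starts from the global weights
            outer_optimizer.step()
            outer_optimizer.zero_grad()
            global_params = [p.detach().clone() for p in model.parameters()]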

torchtitan/train.py

+7 -1
@@ -399,7 +399,13 @@ def train(self):
             job_config, global_step=self.step
         ) as torch_profiler, maybe_enable_memory_snapshot(
             job_config, global_step=self.step
-        ) as memory_profiler:
+        ) as memory_profiler, ft.maybe_semi_sync_training(
+            job_config,
+            ft_manager=self.ft_manager,
+            model=self.model_parts[0],
+            optimizer=self.optimizers,
+            sync_every=job_config.fault_tolerance.sync_steps,
+        ):
             data_iterator = iter(self.dataloader)
             while self.step < job_config.training.steps:
                 self.step += 1
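
The new context manager can be appended to the existing with chain unconditionally because maybe_semi_sync_training falls back to contextlib.nullcontext() when fault_tolerance.semi_sync_method is unset, so entering and exiting the context does nothing. A tiny illustration of that pattern (real_feature and maybe_feature are made-up names for the sketch):

from contextlib import contextmanager, nullcontext

@contextmanager
def real_feature():
    print("feature enabled")    # setup
    yield
    print("feature torn down")  # cleanup

def maybe_feature(enabled: bool):
    # Returns a real context manager or a do-nothing placeholder,
    # so callers can always use a single `with` statement.
    return real_feature() if enabled else nullcontext()

with maybe_feature(enabled=False):
    pass  # the loop body runs unchanged when the feature is off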
