Commit c6341d4 (parent: e0d2de6)

Run local_sgd/diloco in titan

3 files changed: +66 −2 lines

torchtitan/components/ft.py (+46 −1)
@@ -6,6 +6,7 @@
 
 import copy
 import importlib
+from contextlib import nullcontext
 from dataclasses import dataclass
 from typing import Optional
 
@@ -85,13 +86,16 @@ def init_ft_manager(job: JobConfig) -> FTManager:
 
     pg = ft.ProcessGroupNCCL()
 
+    # If a semi-sync training method is specified, the quorum must be synchronous.
+    use_async_quorum = job.fault_tolerance.semi_sync_method is None
+
     return FTManager(
         ft.Manager(
             pg=pg,
             min_replica_size=job.fault_tolerance.min_replica_size,
             load_state_dict=None,
             state_dict=None,
-            use_async_quorum=True,
+            use_async_quorum=use_async_quorum,
             replica_id=f"torchtitan_ft_{job.fault_tolerance.replica_id}",
         ),
         group_size=job.fault_tolerance.group_size,
@@ -158,3 +162,44 @@ def ft_clip_grad_norm_util(total_norm: DTensor) -> torch.Tensor:
         return DTensor.from_local(local_tensor, mesh.mesh, placements)
 
     return total_norm
+
+
+def maybe_semi_sync_training(
+    config: JobConfig,
+    ft_manager: FTManager,
+    model: torch.nn.Module,
+    optimizer: torch.optim.Optimizer,
+    sync_every: int,
+):
+    """
+    If TorchFT is enabled and `fault_tolerance.semi_sync_method` is set, wrap training in the corresponding torchft semi-sync context manager.
+    """
+    semi_sync_method = config.fault_tolerance.semi_sync_method
+    if semi_sync_method is not None:
+        if semi_sync_method.lower() == "diloco":
+            # Create the outer optimizer based on the inner optimizer parameters.
+            params = [group["params"] for group in optimizer.param_groups]
+            params = [param for sublist in params for param in sublist]
+            outer_optimizer = torch.optim.SGD(
+                params, lr=0.7, momentum=0.9, nesterov=True
+            )
+
+            return ft.local_sgd.DiLoCo(
+                manager=ft_manager._manager,
+                model=model,
+                inner_optimizer=optimizer,
+                outer_optimizer=outer_optimizer,
+                sync_every=sync_every,
+            )
+        elif semi_sync_method.lower() == "local_sgd":
+            return ft.local_sgd.LocalSGD(
+                manager=ft_manager._manager,
+                model=model,
+                optimizer=optimizer,
+                sync_every=sync_every,
+            )
+        else:
+            raise ValueError(
+                f"Unknown semi_sync_method: {semi_sync_method}, only 'diloco' and 'local_sgd' are supported."
+            )
+    return nullcontext()
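
For orientation, the sketch below is my own illustration (not torchft's or torchtitan's code) of the pattern that maybe_semi_sync_training wires up: each replica takes ordinary optimizer steps locally and only communicates every sync_every steps. The cross-replica average is stubbed out for a single process; in torchft, the LocalSGD and DiLoCo context managers drive that synchronization through the fault-tolerant Manager.

# Conceptual sketch only: illustrates the periodic-synchronization idea behind
# LocalSGD-style semi-sync training. The averaging step is a single-process
# stand-in; torchft would all-reduce across the live replicas instead.
import torch


def average_parameters_across_replicas(model: torch.nn.Module) -> None:
    # Placeholder for the cross-replica average (a quorum-driven all-reduce in
    # torchft). With a single replica, averaging leaves parameters unchanged.
    with torch.no_grad():
        for p in model.parameters():
            p.copy_(p)


model = torch.nn.Linear(8, 1)
inner_opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
sync_every = 5  # analogous to fault_tolerance.sync_steps

for step in range(1, 21):
    x = torch.randn(4, 8)
    loss = model(x).pow(2).mean()
    inner_opt.zero_grad()
    loss.backward()
    inner_opt.step()  # local (inner) step on every iteration
    if step % sync_every == 0:
        average_parameters_across_replicas(model)  # communicate only here

DiLoCo refines plain local SGD by treating the change in the synchronized parameters as a pseudo-gradient and applying it with a separate outer optimizer, which is why the diff above builds an SGD instance with lr=0.7 and Nesterov momentum 0.9 over the inner optimizer's parameters.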

torchtitan/config_manager.py (+13)
@@ -502,6 +502,19 @@ class FaultTolerance:
     min_replica_size: int = 1
     """The minimum number of FT replica for each step."""
 
+    semi_sync_method: str | None = None
+    """
+    The algorithm to use for semi-sync training. Currently, only "local_sgd" and "diloco"
+    from torchft are supported
+    (https://github.com/pytorch/torchft/blob/360c5c534bdeac959507e9d238ba9f3902d3fda9/torchft/local_sgd.py#L41).
+    """
+
+    sync_steps: int = 5
+    """
+    Number of steps to wait before performing synchronization. This is only used when
+    "semi_sync_method" is set.
+    """
+
 
 @dataclass
 class Experimental:
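
The sketch below is illustrative only (it mirrors the field names and defaults added above, not torchtitan's actual config plumbing) and shows how the two options are consumed: a non-None semi_sync_method forces a synchronous quorum in init_ft_manager, and sync_steps becomes the sync_every interval passed to maybe_semi_sync_training.

# Illustrative sketch, not torchtitan code: field names and defaults taken from
# the diff above; the gating mirrors init_ft_manager in torchtitan/components/ft.py.
from dataclasses import dataclass


@dataclass
class FaultToleranceSketch:
    semi_sync_method: str | None = None  # "local_sgd", "diloco", or None (disabled)
    sync_steps: int = 5                  # only consulted when semi_sync_method is set


cfg = FaultToleranceSketch(semi_sync_method="diloco", sync_steps=10)

# A configured semi-sync method means the quorum must be synchronous.
use_async_quorum = cfg.semi_sync_method is None
assert use_async_quorum is False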

torchtitan/train.py (+7 −1)
@@ -399,7 +399,13 @@ def train(self):
             job_config, global_step=self.step
         ) as torch_profiler, maybe_enable_memory_snapshot(
             job_config, global_step=self.step
-        ) as memory_profiler:
+        ) as memory_profiler, ft.maybe_semi_sync_training(
+            job_config,
+            ft_manager=self.ft_manager,
+            model=self.model_parts[0],
+            optimizer=self.optimizers,
+            sync_every=job_config.fault_tolerance.sync_steps,
+        ):
             data_iterator = iter(self.dataloader)
             while self.step < job_config.training.steps:
                 self.step += 1
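
One design note on the integration above: because maybe_semi_sync_training returns contextlib.nullcontext() when no semi-sync method is configured, it stacks unconditionally into the existing with statement and the training loop needs no extra branching. A standalone sketch of that pattern (ToySemiSync and maybe_semi_sync are made-up names, not torchft or torchtitan APIs):

# Standalone illustration of the "maybe" context-manager pattern used above.
from contextlib import nullcontext


class ToySemiSync:
    def __enter__(self):
        print("semi-sync enabled")
        return self

    def __exit__(self, exc_type, exc, tb):
        print("semi-sync teardown")
        return False


def maybe_semi_sync(method: str | None):
    # Real context manager when configured, a no-op one otherwise.
    return ToySemiSync() if method is not None else nullcontext()


for method in ("local_sgd", None):
    with maybe_semi_sync(method):
        pass  # the training loop body is identical in both cases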
