
Commit e008b65

Run local_sgd/diloco in titan

1 parent e0d2de6

4 files changed: +163 -1 lines


torchtitan/components/ft.py (+46 -1)
@@ -6,6 +6,7 @@
 
 import copy
 import importlib
+from contextlib import nullcontext
 from dataclasses import dataclass
 from typing import Optional
 
@@ -85,13 +86,16 @@ def init_ft_manager(job: JobConfig) -> FTManager:
 
     pg = ft.ProcessGroupNCCL()
 
+    # If a training method is specified, the quorum should be synchronous
+    use_async_quorum = job.fault_tolerance.training_method is None
+
     return FTManager(
         ft.Manager(
             pg=pg,
             min_replica_size=job.fault_tolerance.min_replica_size,
             load_state_dict=None,
             state_dict=None,
-            use_async_quorum=True,
+            use_async_quorum=use_async_quorum,
             replica_id=f"torchtitan_ft_{job.fault_tolerance.replica_id}",
         ),
         group_size=job.fault_tolerance.group_size,
@@ -158,3 +162,44 @@ def ft_clip_grad_norm_util(total_norm: DTensor) -> torch.Tensor:
         return DTensor.from_local(local_tensor, mesh.mesh, placements)
 
     return total_norm
+
+
+def maybe_semi_sync_training(
+    config: JobConfig,
+    ft_manager: FTManager,
+    model: torch.nn.Module,
+    optimizer: torch.optim.Optimizer,
+    sync_every: int,
+):
+    """
+    If TorchFT is enabled and a training_method is set, return the matching semi-sync training context.
+    """
+    training_method = config.fault_tolerance.training_method
+    if training_method is not None:
+        if training_method.lower() == "diloco":
+            # Create the outer optimizer based on the inner optimizer parameters.
+            params = [group["params"] for group in optimizer.param_groups]
+            params = [param for sublist in params for param in sublist]
+            outer_optimizer = torch.optim.SGD(
+                params, lr=0.7, momentum=0.9, nesterov=True
+            )
+
+            return ft.local_sgd.DiLoCo(
+                manager=ft_manager._manager,
+                model=model,
+                inner_optimizer=optimizer,
+                outer_optimizer=outer_optimizer,
+                sync_every=sync_every,
+            )
+        elif training_method.lower() == "local_sgd":
+            return ft.local_sgd.LocalSGD(
+                manager=ft_manager._manager,
+                model=model,
+                optimizer=optimizer,
+                sync_every=sync_every,
+            )
+        else:
+            raise ValueError(
+                f"Unknown training method: {training_method}, only 'diloco' and 'local_sgd' are supported."
+            )
+    return nullcontext()
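Note: maybe_semi_sync_training returns either a torchft LocalSGD/DiLoCo wrapper or a plain nullcontext(), so callers can always enter it with a with statement. The standalone sketch below (plain PyTorch only; the toy linear model and AdamW inner optimizer are illustrative assumptions, not torchtitan objects) shows how the DiLoCo branch flattens the inner optimizer's param groups into a single parameter list for the hard-coded outer SGD optimizer:

import torch

# Toy stand-ins for the real model and inner optimizer (assumptions for illustration).
model = torch.nn.Linear(16, 16)
inner_optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

# Mirror the two list comprehensions in maybe_semi_sync_training: flatten every
# param group of the inner optimizer into one list of parameters.
params = [group["params"] for group in inner_optimizer.param_groups]
params = [param for sublist in params for param in sublist]

# Outer optimizer that DiLoCo would use to apply the averaged pseudo-gradients;
# the lr/momentum/nesterov values match the ones hard-coded in the diff above.
outer_optimizer = torch.optim.SGD(params, lr=0.7, momentum=0.9, nesterov=True)
print(f"{len(params)} parameters shared by the inner and outer optimizers")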

torchtitan/config_manager.py (+7)
@@ -502,6 +502,13 @@ class FaultTolerance:
     min_replica_size: int = 1
     """The minimum number of FT replica for each step."""
 
+    training_method: str | None = "diloco"
+    """
+    The algorithm to use for semi-sync training. Currently, only "local_sgd" and "diloco" from
+    torchft are supported
+    (https://github.com/pytorch/torchft/blob/360c5c534bdeac959507e9d238ba9f3902d3fda9/torchft/local_sgd.py#L41)
+    """
+
 
 
 @dataclass
 class Experimental:
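A minimal sketch of how the new training_method knob drives the quorum choice in init_ft_manager. The trimmed dataclass below is a stand-in for torchtitan's FaultTolerance config section (only a few fields reproduced), not the real class:

from dataclasses import dataclass


@dataclass
class FaultTolerance:  # trimmed stand-in; the real section lives in torchtitan/config_manager.py
    replica_id: int = 0
    group_size: int = 0
    min_replica_size: int = 1
    training_method: str | None = "diloco"


cfg = FaultTolerance(training_method="local_sgd")
# Mirrors init_ft_manager: configuring any semi-sync method forces a synchronous quorum.
use_async_quorum = cfg.training_method is None
assert use_async_quorum is False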
New file (+1 line)

@@ -0,0 +1 @@
+This folder contains experiments for running TorchTitan alongside TorchFT (https://github.com/pytorch/torchft).
New file (+109 lines)

@@ -0,0 +1,109 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import time
+from datetime import timedelta
+from typing import Optional
+
+import torch
+
+import torchtitan.components.ft as ft
+
+from torch.distributed.elastic.multiprocessing.errors import record
+from torchtitan.config_manager import JobConfig
+from torchtitan.distributed import utils as dist_utils
+from torchtitan.tools.logging import init_logger, logger
+from torchtitan.tools.profiling import (
+    maybe_enable_memory_snapshot,
+    maybe_enable_profiling,
+)
+from torchtitan.train import Trainer
+
+
+class FtTrainer(Trainer):
+    # Override the train method to add fault tolerance
+    @record
+    def train(self):
+        job_config = self.job_config
+
+        self.checkpointer.load(step=job_config.checkpoint.load_step)
+        logger.info(f"Training starts at step {self.step + 1}.")
+
+        with maybe_enable_profiling(
+            job_config, global_step=self.step
+        ) as torch_profiler, maybe_enable_memory_snapshot(
+            job_config, global_step=self.step
+        ) as memory_profiler, ft.maybe_semi_sync_training(
+            job_config,
+            ft_manager=self.ft_manager,
+            model=self.model_parts[0],
+            optimizer=self.optimizers,
+            sync_every=2,
+        ) as semi_sync_training:
+            data_iterator = iter(self.dataloader)
+
+            while self.step < job_config.training.steps:
+                self.step += 1
+                self.gc_handler.run(self.step)
+                inputs, labels = self.next_batch(data_iterator)
+                self.train_step(inputs, labels)
+                self.checkpointer.save(
+                    self.step, force=(self.step == job_config.training.steps)
+                )
+
+                # signal the profiler that the next profiling step has started
+                if torch_profiler:
+                    torch_profiler.step()
+                if memory_profiler:
+                    memory_profiler.step()
+
+                # reduce timeout after first train step for faster signal
+                # (assuming lazy init and compilation are finished)
+                if self.step == 1:
+                    dist_utils.set_pg_timeouts(
+                        timeout=timedelta(
+                            seconds=job_config.comm.train_timeout_seconds
+                        ),
+                        world_mesh=self.world_mesh,
+                    )
+
+        if torch.distributed.get_rank() == 0:
+            logger.info("Sleeping 2 seconds for other ranks to complete")
+            time.sleep(2)
+
+        self.metrics_processor.close()
+        logger.info("Training completed")
+
+
+if __name__ == "__main__":
+    init_logger()
+    config = JobConfig()
+    config.maybe_add_custom_args()
+    config.parse_args()
+    trainer: Optional[Trainer] = None
+
+    try:
+        trainer = FtTrainer(config)
+
+        if config.checkpoint.create_seed_checkpoint:
+            assert int(
+                os.environ["WORLD_SIZE"]
+            ) == 1, "Must create seed checkpoint using a single device, to disable sharding."
+            assert (
+                config.checkpoint.enable_checkpoint
+            ), "Must enable checkpointing when creating a seed checkpoint."
+            trainer.checkpointer.save(curr_step=0, force=True)
+            logger.info("Created seed checkpoint")
+        else:
+            trainer.train()
+    finally:
+        if trainer:
+            trainer.close()
+
+        if torch.distributed.is_initialized():
+            torch.distributed.destroy_process_group()
+            logger.info("Process group destroyed.")
