[FT] Support local_sgd / diloco in titan #1122

Merged: 1 commit, Apr 29, 2025
58 changes: 56 additions & 2 deletions torchtitan/components/ft.py
@@ -6,8 +6,9 @@

import copy
import importlib
from contextlib import nullcontext
from dataclasses import dataclass
from typing import Optional
from typing import ContextManager, Optional, TYPE_CHECKING, Union

import torch
import torch.distributed as dist
@@ -22,6 +23,9 @@
if importlib.util.find_spec("torchft") is not None:
import torchft as ft

if TYPE_CHECKING:
from torchft import local_sgd

has_torchft = True
else:
has_torchft = False
@@ -85,13 +89,16 @@ def init_ft_manager(job: JobConfig) -> FTManager:

pg = ft.ProcessGroupNCCL()

# If a semi-sync training method is specified, the quorum should be synchronous
use_async_quorum = job.fault_tolerance.semi_sync_method is None

return FTManager(
ft.Manager(
pg=pg,
min_replica_size=job.fault_tolerance.min_replica_size,
load_state_dict=None,
state_dict=None,
use_async_quorum=True,
use_async_quorum=use_async_quorum,
replica_id=f"torchtitan_ft_{job.fault_tolerance.replica_id}",
),
group_size=job.fault_tolerance.group_size,
@@ -158,3 +165,50 @@ def ft_clip_grad_norm_util(total_norm: DTensor) -> torch.Tensor:
return DTensor.from_local(local_tensor, mesh.mesh, placements)

return total_norm


def maybe_semi_sync_training(
Contributor: can we add typing?

Member (Author): will do

config: JobConfig,
ft_manager: FTManager,
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
sync_every: int,
) -> ContextManager[Union["local_sgd.DiLoCo", "local_sgd.LocalSGD", None]]:
"""
If TorchFT is enabled and fault_tolerance.semi_sync_method is set, wrap training in the corresponding torchft semi-sync context manager; otherwise return a nullcontext.
"""
semi_sync_method = config.fault_tolerance.semi_sync_method
torchft_enabled = config.fault_tolerance.enable
if torchft_enabled and semi_sync_method is not None:
from torchft import local_sgd

assert (
ft_manager._manager is not None
), "FTManager must be enabled to use semi-sync training."
if semi_sync_method.lower() == "diloco":
# Create the outer optimizer based on the inner optimizer parameters.
params = [group["params"] for group in optimizer.param_groups]
params = [param for sublist in params for param in sublist]
outer_optimizer = torch.optim.SGD(
params, lr=0.7, momentum=0.9, nesterov=True
)

return local_sgd.DiLoCo(
manager=ft_manager._manager,
model=model,
inner_optimizer=optimizer,
outer_optimizer=outer_optimizer,
sync_every=sync_every,
)
elif semi_sync_method.lower() == "local_sgd":
return local_sgd.LocalSGD(
manager=ft_manager._manager,
model=model,
optimizer=optimizer,
sync_every=sync_every,
)
else:
raise ValueError(
f"Unknown training method: {semi_sync_method}, only 'diloco' and 'local_sgd' are supported."
)
return nullcontext()
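
For context, a minimal usage sketch of the new helper. This is a sketch only: job_config, ft_manager, model, optimizer, and dataloader are assumed to be set up elsewhere, and the loss computation is a placeholder; only the wrapping pattern matters.

```python
# Sketch: wrap the whole training loop in the semi-sync context. Assumes the
# surrounding objects (job_config, ft_manager, model, optimizer, dataloader)
# already exist.
from torchtitan.components import ft

with ft.maybe_semi_sync_training(
    job_config,
    ft_manager=ft_manager,
    model=model,
    optimizer=optimizer,
    sync_every=job_config.fault_tolerance.sync_steps,
):
    for batch in dataloader:
        loss = model(batch).mean()  # placeholder loss
        loss.backward()
        optimizer.step()   # DiLoCo/LocalSGD hooks trigger a sync every `sync_every` steps
        optimizer.zero_grad()
```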
13 changes: 13 additions & 0 deletions torchtitan/config_manager.py
@@ -502,6 +502,19 @@ class FaultTolerance:
min_replica_size: int = 1
"""The minimum number of FT replica for each step."""

semi_sync_method: str | None = None
"""
The algorithm to use for semi-sync training. Currently, only "local_sgd" and "diloco" from
torchft are supported
(https://github.com/pytorch/torchft/blob/360c5c534bdeac959507e9d238ba9f3902d3fda9/torchft/local_sgd.py#L41)
"""

sync_steps: int = 5
"""
Number of steps to wait before performing synchronization. This is only used when "semi_sync_method"
is set.
"""


@dataclass
class Experimental:
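As a hedged illustration, the new fields sit alongside the existing fault-tolerance options; in a real run they would come from the job's TOML config file or the equivalent CLI overrides. A minimal Python sketch follows (the default-constructed JobConfig is an assumption for brevity; the values are examples, not defaults from this PR):

```python
# Illustrative only: field names match this diff, values are examples.
from torchtitan.config_manager import JobConfig

job_config = JobConfig()  # assumed to be default-constructible
job_config.fault_tolerance.enable = True
job_config.fault_tolerance.replica_id = 0
job_config.fault_tolerance.min_replica_size = 1
job_config.fault_tolerance.semi_sync_method = "diloco"  # or "local_sgd"
job_config.fault_tolerance.sync_steps = 5                # outer sync every 5 optimizer steps
```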
8 changes: 7 additions & 1 deletion torchtitan/train.py
@@ -399,7 +399,13 @@ def train(self):
job_config, global_step=self.step
) as torch_profiler, maybe_enable_memory_snapshot(
job_config, global_step=self.step
) as memory_profiler:
) as memory_profiler, ft.maybe_semi_sync_training(
job_config,
ft_manager=self.ft_manager,
model=self.model_parts[0],
optimizer=self.optimizers,
sync_every=job_config.fault_tolerance.sync_steps,
):
Contributor: uh, I just realized that this context can be initialized in trainer.__init__() and put into dist_utils.get_train_context. If my understanding is correct, that will only require changing dist_utils.get_train_context and its caller. Let me know if this makes sense.

Contributor: Here it's entering the context manager for the entire training run, across iterations, while dist_utils.get_train_context is a per-iteration context manager. Which way are we supposed to use the FT contexts?

Member (Author): ah interesting. Yeah, the FT context is across iterations since it adds hooks to the optimizers and performs synchronization every N iterations (based on the optimizer .step() calls).

Contributor: I see. Maybe it'd be good to organize them into two context-manager util functions, one for the overall training run and the other for each training iteration.

Member (Author): good idea, I will add it in a follow-up PR!

data_iterator = iter(self.dataloader)
while self.step < job_config.training.steps:
self.step += 1
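
To make the scoping distinction from the review thread above concrete, here is a toy sketch (nothing torchft-specific; the names are illustrative): the semi-sync context is entered once and spans all iterations, while a per-iteration context such as the one from dist_utils.get_train_context is re-entered on every step.

```python
# Toy sketch of the two context-manager scopes discussed in the review thread.
from contextlib import contextmanager

@contextmanager
def whole_run_context():
    # Analogous to ft.maybe_semi_sync_training(...): entered once; it would
    # install optimizer hooks that sync every N .step() calls.
    print("enter: install semi-sync hooks")
    try:
        yield
    finally:
        print("exit: remove semi-sync hooks")

@contextmanager
def per_step_context(step: int):
    # Analogous to dist_utils.get_train_context(...): re-entered each iteration.
    yield

with whole_run_context():
    for step in range(3):
        with per_step_context(step):
            pass  # forward / backward / optimizer.step() would go here
```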