 
 import copy
 import importlib
+from contextlib import nullcontext
 from dataclasses import dataclass
-from typing import Optional
+from typing import ContextManager, Optional, TYPE_CHECKING, Union
 
 import torch
 import torch.distributed as dist

 if importlib.util.find_spec("torchft") is not None:
     import torchft as ft
 
+    if TYPE_CHECKING:
+        from torchft import local_sgd
+
     has_torchft = True
 else:
     has_torchft = False

@@ -85,13 +89,16 @@ def init_ft_manager(job: JobConfig) -> FTManager:
 
     pg = ft.ProcessGroupNCCL()
 
+    # If a semi-sync training method is specified, the quorum should be synchronous.
+    use_async_quorum = job.fault_tolerance.semi_sync_method is None
+
     return FTManager(
         ft.Manager(
             pg=pg,
             min_replica_size=job.fault_tolerance.min_replica_size,
             load_state_dict=None,
             state_dict=None,
-            use_async_quorum=True,
+            use_async_quorum=use_async_quorum,
             replica_id=f"torchtitan_ft_{job.fault_tolerance.replica_id}",
         ),
         group_size=job.fault_tolerance.group_size,
@@ -158,3 +165,50 @@ def ft_clip_grad_norm_util(total_norm: DTensor) -> torch.Tensor:
             return DTensor.from_local(local_tensor, mesh.mesh, placements)
 
     return total_norm
+
+
+def maybe_semi_sync_training(
+    config: JobConfig,
+    ft_manager: FTManager,
+    model: torch.nn.Module,
+    optimizer: torch.optim.Optimizer,
+    sync_every: int,
+) -> ContextManager[Union["local_sgd.DiLoCo", "local_sgd.LocalSGD", None]]:
+    """
+    Return a semi-sync training context (LocalSGD or DiLoCo) when TorchFT and semi_sync_method are enabled; otherwise a null context.
+    """
+    semi_sync_method = config.fault_tolerance.semi_sync_method
+    torchft_enabled = config.fault_tolerance.enable
+    if torchft_enabled and semi_sync_method is not None:
+        from torchft import local_sgd
+
+        assert (
+            ft_manager._manager is not None
+        ), "FTManager must be enabled to use semi-sync training."
+        if semi_sync_method.lower() == "diloco":
+            # Create the outer optimizer based on the inner optimizer parameters.
+            params = [group["params"] for group in optimizer.param_groups]
+            params = [param for sublist in params for param in sublist]
+            outer_optimizer = torch.optim.SGD(
+                params, lr=0.7, momentum=0.9, nesterov=True
+            )
+
+            return local_sgd.DiLoCo(
+                manager=ft_manager._manager,
+                model=model,
+                inner_optimizer=optimizer,
+                outer_optimizer=outer_optimizer,
+                sync_every=sync_every,
+            )
+        elif semi_sync_method.lower() == "local_sgd":
+            return local_sgd.LocalSGD(
+                manager=ft_manager._manager,
+                model=model,
+                optimizer=optimizer,
+                sync_every=sync_every,
+            )
+        else:
+            raise ValueError(
+                f"Unknown training method: {semi_sync_method}, only 'diloco' and 'local_sgd' are supported."
+            )
+    return nullcontext()
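
For context, a minimal usage sketch of how the new `maybe_semi_sync_training` context manager might wrap a training loop. This is an illustration, not part of the diff: `job_config` (a parsed `JobConfig`), `ft_manager` (from `init_ft_manager`), and the `fault_tolerance.sync_steps` field are assumed to exist in the surrounding trainer.

```python
import torch

# Sketch only: job_config and ft_manager are assumed to come from the trainer
# setup (config parsing + init_ft_manager); sync_steps is an assumed field name.
model = torch.nn.Linear(16, 16)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

with maybe_semi_sync_training(
    config=job_config,
    ft_manager=ft_manager,
    model=model,
    optimizer=optimizer,
    sync_every=job_config.fault_tolerance.sync_steps,  # assumed config field
):
    for _ in range(100):
        optimizer.zero_grad()
        loss = model(torch.randn(8, 16)).square().mean()
        loss.backward()
        # Inside the context, each optimizer.step() is a local step; LocalSGD /
        # DiLoCo synchronize replicas every `sync_every` steps, while the
        # nullcontext path leaves ordinary synchronous training unchanged.
        optimizer.step()
```

The training loop itself needs no explicit sync calls: the torchft context managers count optimizer steps internally and trigger the averaging (or outer-optimizer update for DiLoCo) on their own.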