@@ -6,6 +6,7 @@

 import copy
 import importlib
+from contextlib import nullcontext
 from dataclasses import dataclass
 from typing import Optional

@@ -85,13 +86,16 @@ def init_ft_manager(job: JobConfig) -> FTManager:

     pg = ft.ProcessGroupNCCL()

+    # If a semi-sync training method is specified, the quorum must be synchronous.
+    use_async_quorum = job.fault_tolerance.training_method is None
+
     return FTManager(
         ft.Manager(
             pg=pg,
             min_replica_size=job.fault_tolerance.min_replica_size,
             load_state_dict=None,
             state_dict=None,
-            use_async_quorum=True,
+            use_async_quorum=use_async_quorum,
             replica_id=f"torchtitan_ft_{job.fault_tolerance.replica_id}",
         ),
         group_size=job.fault_tolerance.group_size,
@@ -158,3 +162,44 @@ def ft_clip_grad_norm_util(total_norm: DTensor) -> torch.Tensor:
         return DTensor.from_local(local_tensor, mesh.mesh, placements)

     return total_norm
+
+
+def maybe_semi_sync_training(
+    config: JobConfig,
+    ft_manager: FTManager,
+    model: torch.nn.Module,
+    optimizer: torch.optim.Optimizer,
+    sync_every: int,
+):
+    """
+    If TorchFT is enabled and a semi-sync training method is configured,
+    return the corresponding training context manager; otherwise return
+    a no-op context.
+    """
+    training_method = config.fault_tolerance.training_method
+    if training_method is not None:
+        if training_method.lower() == "diloco":
+            # Create the outer optimizer from the inner optimizer's parameters.
+            params = [group["params"] for group in optimizer.param_groups]
+            params = [param for sublist in params for param in sublist]
+            outer_optimizer = torch.optim.SGD(
+                params, lr=0.7, momentum=0.9, nesterov=True
+            )
+
+            return ft.local_sgd.DiLoCo(
+                manager=ft_manager._manager,
+                model=model,
+                inner_optimizer=optimizer,
+                outer_optimizer=outer_optimizer,
+                sync_every=sync_every,
+            )
+        elif training_method.lower() == "local_sgd":
+            return ft.local_sgd.LocalSGD(
+                manager=ft_manager._manager,
+                model=model,
+                optimizer=optimizer,
+                sync_every=sync_every,
+            )
+        else:
+            raise ValueError(
+                f"Unknown training method: {training_method}; only 'diloco' and 'local_sgd' are supported."
+            )
+    return nullcontext()
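
For context (not part of the diff): the returned object is meant to be used as a context manager around the training loop, since torchft's `LocalSGD` and `DiLoCo` implement `__enter__`/`__exit__`, and `nullcontext()` keeps the disabled case uniform. The hard-coded outer-optimizer settings (SGD with lr=0.7 and Nesterov momentum 0.9) are the values reported in the DiLoCo paper. A minimal usage sketch, where `train_step`, `data_iter`, and the `fault_tolerance.sync_steps` config field are illustrative assumptions rather than part of this change:

```python
# Hypothetical usage sketch; train_step, data_iter, and
# job_config.fault_tolerance.sync_steps are assumptions for illustration.
with maybe_semi_sync_training(
    job_config,
    ft_manager=ft_manager,
    model=model,
    optimizer=optimizer,
    sync_every=job_config.fault_tolerance.sync_steps,
):
    for step in range(job_config.training.steps):
        loss = train_step(model, optimizer, next(data_iter))
```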