@@ -33,38 +33,11 @@
     ProcessGroupBabyNCCL,
     ProcessGroupGloo,
 )
+from torchft.test.diloco_trainer import DiLoCoTrainer, MultiMyModel
 
 logger: logging.Logger = logging.getLogger(__name__)
 
 
-class MultiMyModel(torch.nn.Module):
-    def __init__(self, in_dim: int = 3, out_dim: int = 4, n_layers: int = 1) -> None:
-        super().__init__()
-        self.in_dim = in_dim
-
-        self.layers = torch.nn.ModuleList()
-        for i in range(n_layers):
-            self.layers.append(MyModel(in_dim, out_dim))
-            in_dim, out_dim = out_dim, in_dim
-
-        self.out_dim = in_dim
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        for layer in self.layers:
-            x = layer(x)
-        return x
-
-    def get_rand_inputs(
-        self, batch_size: int, device: torch.device = torch.device("cpu")
-    ) -> torch.Tensor:
-        return torch.rand(batch_size, self.in_dim, device=device)
-
-    def get_rand_labels(
-        self, batch_size: int, device: torch.device = torch.device("cpu")
-    ) -> torch.Tensor:
-        return torch.randint(self.out_dim, (batch_size,), device=device)
-
-
 def local_sgd_train_loop(
     rank: int,
     store_port: int,
@@ -148,158 +121,11 @@ def diloco_train_loop(
     diloco_args = train_loop_args.get("diloco_args", {})
 
     with ExitStack() as stack:
-        # Declare the model and optimizers
-        m = MultiMyModel(2, 3, n_fragments)
-        m.load_state_dict(model_state_dict)
-        m.to(device)
-
-        # Setup optimizers
-        inner_optimizer: optim.Optimizer = torch.optim.AdamW(
-            m.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95)
+        trainer = DiLoCoTrainer(
+            rank, store_port, device, runner, model_state_dict, n_fragments, diloco_args
         )
-
-        # Create one outer optimizer per fragment
-        outer_optimizers = []
-        for _, layer in enumerate(m.layers):
-            outer_optimizers.append(
-                torch.optim.SGD(layer.parameters(), lr=0.7, momentum=0.9, nesterov=True)
-            )
-
-        # pyre-ignore[53]
-        def load_state_dict(state_dict: Dict[str, Dict[str, object]]) -> None:
-            m.load_state_dict(state_dict["model"])
-            m.to(device)
-
-            # Load original parameters for each fragment
-            for i, fragment in enumerate(diloco._fragments):
-                fragment.original_parameters = cast(
-                    Dict[str, torch.Tensor], state_dict["original_params"][f"{i}"]
-                )
-
-            for fragment in diloco._fragments:
-                for name in fragment.original_parameters.keys():
-                    fragment.original_parameters[name] = fragment.original_parameters[
-                        name
-                    ].to(device)
-
-            inner_optimizer.load_state_dict(state_dict["inner_optim"])
-            for i, optimizer in enumerate(outer_optimizers):
-                optimizer.load_state_dict(state_dict[f"outer_optim"][f"{i}"])
-
-        def state_dict() -> Dict[str, Dict[str, object]]:  # pyre-ignore[53]
-            return {
-                "model": m.state_dict(),
-                "original_params": {
-                    f"{i}": fragment.original_parameters
-                    for i, fragment in enumerate(diloco._fragments)
-                },
-                "inner_optim": inner_optimizer.state_dict(),
-                "outer_optim": {
-                    f"{i}": optimizer.state_dict()
-                    for i, optimizer in enumerate(outer_optimizers)
-                },
-            }
-
-        print(f"worker {runner.replica_id=} {rank=} {runner.world_size=} starting")
-
-        if device.type == "cuda":
-            pg = FakeProcessGroupWrapper(ProcessGroupBabyNCCL())
-        else:
-            pg = FakeProcessGroupWrapper(
-                ProcessGroupGloo(timeout=timedelta(seconds=10))
-            )
-        manager = Manager(
-            pg=pg,
-            min_replica_size=2,
-            use_async_quorum=False,
-            load_state_dict=load_state_dict,
-            state_dict=state_dict,
-            replica_id=str(runner.replica_id),
-            store_addr="localhost",
-            store_port=store_port,
-            rank=rank,
-            world_size=runner.world_size,
-            lighthouse_addr=runner.lighthouse_address,
-            port=19530 + runner.replica_id,
-            connect_timeout=timedelta(seconds=10),
-            quorum_timeout=timedelta(seconds=10),
-            timeout=timedelta(seconds=10),
-            # pyre-fixme[6]: Incompatible parameter type
-            **runner.manager_args,
-        )
-        runner.event_injector.set_pg(pg)
-        stack.callback(manager.shutdown)
-        # initialize default group for device mesh to work
-        if not torch.distributed.is_initialized():
-            # TODO: remove this try-except once pytorch is updated to 2.8.0 and can use localhost:0
-            try:
-                torch.distributed.init_process_group(
-                    init_method="tcp://localhost:0",
-                    rank=rank,
-                    world_size=runner.world_size,
-                )
-            except ValueError:
-                os.environ["MASTER_ADDR"] = "localhost"
-                os.environ["MASTER_PORT"] = "0"
-                os.environ["WORLD_SIZE"] = str(runner.world_size)
-                os.environ["RANK"] = str(rank)
-
-        device_type = device.type
-        ft_device_mesh = ft_init_device_mesh(
-            device_type=device_type,
-            mesh_shape=(runner.world_size, 1),
-            mesh_dim_names=("replicate", "none"),
-            replicate_dim=0,
-            manager=manager,
-        )
-        for layer in m.layers:
-            if isinstance(layer, nn.Linear):
-                for param in layer.parameters():
-                    param = DTensor.from_local(
-                        param,
-                        device_mesh=ft_device_mesh,
-                    )
-
-        criterion = nn.CrossEntropyLoss()
-        all_state_dicts = {}
-
-        if "sync_every" not in diloco_args:
-            diloco_args["sync_every"] = 2
-
-        with DiLoCo(
-            manager,
-            [layer for layer in m.layers],
-            inner_optimizer,
-            outer_optimizers,
-            backup_device=device,
-            **diloco_args,
-        ) as diloco:
-            while True:
-                runner.event_injector.check(rank, manager.current_step())
-
-                manager_curr_step = manager.current_step()
-                if manager_curr_step not in all_state_dicts:
-                    all_state_dicts[manager_curr_step] = copy.deepcopy(
-                        manager._manager_state_dict()
-                    )
-
-                batch_size = 1
-                inputs = m.get_rand_inputs(batch_size, device=device)
-                labels = m.get_rand_labels(batch_size, device=device)
-
-                out = m(inputs)
-                loss = criterion(out, labels)
-
-                inner_optimizer.zero_grad()
-                loss.backward()
-                inner_optimizer.step()
-
-                # after 4 model updates then break
-                if manager.current_step() >= 4:
-                    break
-
-            # return state_dict so we can check consistency
-            return all_state_dicts
+        stack.callback(trainer.manager.shutdown)
+        return trainer.train_loop()
     return {}
 
 
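Taken together, the two hunks replace the inline model, optimizer, Manager, and DiLoCo setup in diloco_train_loop with the shared DiLoCoTrainer helper. As a rough sketch of the resulting function under stated assumptions: the signature and the model_state_dict / n_fragments unpacking are illustrative guesses, and only the ExitStack body mirrors the added lines above.

# Sketch only: the shape of diloco_train_loop after this change, assuming the
# DiLoCoTrainer API visible in the added lines (constructor plus train_loop()).
from contextlib import ExitStack
from typing import Any, Dict

import torch

from torchft.test.diloco_trainer import DiLoCoTrainer


def diloco_train_loop(
    rank: int,
    store_port: int,
    device: torch.device,
    runner: Any,  # assumed: the test Runner carrying replica_id, world_size, etc.
    train_loop_args: Dict[str, Any],
) -> Dict[str, Dict[str, object]]:
    # Assumed unpacking; only the diloco_args line is visible as diff context.
    model_state_dict = train_loop_args.get("model_state_dict", {})
    n_fragments = train_loop_args.get("n_fragments", 1)
    diloco_args = train_loop_args.get("diloco_args", {})

    with ExitStack() as stack:
        # DiLoCoTrainer now owns model, optimizer, Manager, and DiLoCo setup.
        trainer = DiLoCoTrainer(
            rank, store_port, device, runner, model_state_dict, n_fragments, diloco_args
        )
        # Shut the Manager down even if training raises.
        stack.callback(trainer.manager.shutdown)
        # Returns the per-step state dicts used by the consistency checks.
        return trainer.train_loop()
    return {}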