
Commit b8fe07e

refactor diloco test (#232)
Summary:
- move the training loop to a separate file
- convert it into a class so that methods can be overridden without having to duplicate code
1 parent a0acd51 commit b8fe07e
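
Note: only part of the new helper's interface is visible from this commit's call sites below: the DiLoCoTrainer constructor arguments, a manager attribute used for shutdown, and a train_loop() method that returns the collected state dicts. A minimal sketch of that shape follows; everything beyond those three points (attribute types, method bodies) is an assumption, not the actual contents of torchft/test/diloco_trainer.py.

# Illustrative sketch only -- the real implementation lives in
# torchft/test/diloco_trainer.py and is not reproduced here.
from typing import Any, Dict

import torch


class DiLoCoTrainerSketch:
    """Shape implied by the diff: construct, hook manager.shutdown, run train_loop()."""

    def __init__(
        self,
        rank: int,
        store_port: int,
        device: torch.device,
        runner: Any,  # the integration-test runner object passed through by the caller
        model_state_dict: Dict[str, Any],
        n_fragments: int,
        diloco_args: Dict[str, Any],
    ) -> None:
        self.rank = rank
        self.device = device
        self.diloco_args = diloco_args
        # In the real class this is a torchft Manager built from the arguments above;
        # the test relies on it via stack.callback(trainer.manager.shutdown).
        self.manager: Any = None  # placeholder

    def train_loop(self) -> Dict[int, Any]:
        # Runs the DiLoCo steps and returns per-step state dicts for the
        # consistency checks, as the removed inline loop did.
        return {}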


2 files changed: +315 -179 lines


torchft/local_sgd_integ_test.py

Lines changed: 5 additions & 179 deletions
@@ -33,38 +33,11 @@
     ProcessGroupBabyNCCL,
     ProcessGroupGloo,
 )
+from torchft.test.diloco_trainer import DiLoCoTrainer, MultiMyModel

 logger: logging.Logger = logging.getLogger(__name__)


-class MultiMyModel(torch.nn.Module):
-    def __init__(self, in_dim: int = 3, out_dim: int = 4, n_layers: int = 1) -> None:
-        super().__init__()
-        self.in_dim = in_dim
-
-        self.layers = torch.nn.ModuleList()
-        for i in range(n_layers):
-            self.layers.append(MyModel(in_dim, out_dim))
-            in_dim, out_dim = out_dim, in_dim
-
-        self.out_dim = in_dim
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        for layer in self.layers:
-            x = layer(x)
-        return x
-
-    def get_rand_inputs(
-        self, batch_size: int, device: torch.device = torch.device("cpu")
-    ) -> torch.Tensor:
-        return torch.rand(batch_size, self.in_dim, device=device)
-
-    def get_rand_labels(
-        self, batch_size: int, device: torch.device = torch.device("cpu")
-    ) -> torch.Tensor:
-        return torch.randint(self.out_dim, (batch_size,), device=device)
-
-
 def local_sgd_train_loop(
     rank: int,
     store_port: int,
@@ -148,158 +121,11 @@ def diloco_train_loop(
     diloco_args = train_loop_args.get("diloco_args", {})

     with ExitStack() as stack:
-        # Declare the model and optimizers
-        m = MultiMyModel(2, 3, n_fragments)
-        m.load_state_dict(model_state_dict)
-        m.to(device)
-
-        # Setup optimizers
-        inner_optimizer: optim.Optimizer = torch.optim.AdamW(
-            m.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95)
+        trainer = DiLoCoTrainer(
+            rank, store_port, device, runner, model_state_dict, n_fragments, diloco_args
         )
-
-        # Create one outer optimizer per fragment
-        outer_optimizers = []
-        for _, layer in enumerate(m.layers):
-            outer_optimizers.append(
-                torch.optim.SGD(layer.parameters(), lr=0.7, momentum=0.9, nesterov=True)
-            )
-
-        # pyre-ignore[53]
-        def load_state_dict(state_dict: Dict[str, Dict[str, object]]) -> None:
-            m.load_state_dict(state_dict["model"])
-            m.to(device)
-
-            # Load original parameters for each fragment
-            for i, fragment in enumerate(diloco._fragments):
-                fragment.original_parameters = cast(
-                    Dict[str, torch.Tensor], state_dict["original_params"][f"{i}"]
-                )
-
-            for fragment in diloco._fragments:
-                for name in fragment.original_parameters.keys():
-                    fragment.original_parameters[name] = fragment.original_parameters[
-                        name
-                    ].to(device)
-
-            inner_optimizer.load_state_dict(state_dict["inner_optim"])
-            for i, optimizer in enumerate(outer_optimizers):
-                optimizer.load_state_dict(state_dict[f"outer_optim"][f"{i}"])
-
-        def state_dict() -> Dict[str, Dict[str, object]]:  # pyre-ignore[53]
-            return {
-                "model": m.state_dict(),
-                "original_params": {
-                    f"{i}": fragment.original_parameters
-                    for i, fragment in enumerate(diloco._fragments)
-                },
-                "inner_optim": inner_optimizer.state_dict(),
-                "outer_optim": {
-                    f"{i}": optimizer.state_dict()
-                    for i, optimizer in enumerate(outer_optimizers)
-                },
-            }
-
-        print(f"worker {runner.replica_id=} {rank=} {runner.world_size=} starting")
-
-        if device.type == "cuda":
-            pg = FakeProcessGroupWrapper(ProcessGroupBabyNCCL())
-        else:
-            pg = FakeProcessGroupWrapper(
-                ProcessGroupGloo(timeout=timedelta(seconds=10))
-            )
-        manager = Manager(
-            pg=pg,
-            min_replica_size=2,
-            use_async_quorum=False,
-            load_state_dict=load_state_dict,
-            state_dict=state_dict,
-            replica_id=str(runner.replica_id),
-            store_addr="localhost",
-            store_port=store_port,
-            rank=rank,
-            world_size=runner.world_size,
-            lighthouse_addr=runner.lighthouse_address,
-            port=19530 + runner.replica_id,
-            connect_timeout=timedelta(seconds=10),
-            quorum_timeout=timedelta(seconds=10),
-            timeout=timedelta(seconds=10),
-            # pyre-fixme[6]: Incompatible parameter type
-            **runner.manager_args,
-        )
-        runner.event_injector.set_pg(pg)
-        stack.callback(manager.shutdown)
-        # initialize default group for device mesh to work
-        if not torch.distributed.is_initialized():
-            # TODO: remove this try-except once pytorch is updated to 2.8.0 and can use localhost:0
-            try:
-                torch.distributed.init_process_group(
-                    init_method="tcp://localhost:0",
-                    rank=rank,
-                    world_size=runner.world_size,
-                )
-            except ValueError:
-                os.environ["MASTER_ADDR"] = "localhost"
-                os.environ["MASTER_PORT"] = "0"
-                os.environ["WORLD_SIZE"] = str(runner.world_size)
-                os.environ["RANK"] = str(rank)
-
-        device_type = device.type
-        ft_device_mesh = ft_init_device_mesh(
-            device_type=device_type,
-            mesh_shape=(runner.world_size, 1),
-            mesh_dim_names=("replicate", "none"),
-            replicate_dim=0,
-            manager=manager,
-        )
-        for layer in m.layers:
-            if isinstance(layer, nn.Linear):
-                for param in layer.parameters():
-                    param = DTensor.from_local(
-                        param,
-                        device_mesh=ft_device_mesh,
-                    )
-
-        criterion = nn.CrossEntropyLoss()
-        all_state_dicts = {}
-
-        if "sync_every" not in diloco_args:
-            diloco_args["sync_every"] = 2
-
-        with DiLoCo(
-            manager,
-            [layer for layer in m.layers],
-            inner_optimizer,
-            outer_optimizers,
-            backup_device=device,
-            **diloco_args,
-        ) as diloco:
-            while True:
-                runner.event_injector.check(rank, manager.current_step())
-
-                manager_curr_step = manager.current_step()
-                if manager_curr_step not in all_state_dicts:
-                    all_state_dicts[manager_curr_step] = copy.deepcopy(
-                        manager._manager_state_dict()
-                    )
-
-                batch_size = 1
-                inputs = m.get_rand_inputs(batch_size, device=device)
-                labels = m.get_rand_labels(batch_size, device=device)
-
-                out = m(inputs)
-                loss = criterion(out, labels)
-
-                inner_optimizer.zero_grad()
-                loss.backward()
-                inner_optimizer.step()
-
-                # after 4 model updates then break
-                if manager.current_step() >= 4:
-                    break
-
-            # return state_dict so we can check consistency
-            return all_state_dicts
+        stack.callback(trainer.manager.shutdown)
+        return trainer.train_loop()
     return {}
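
Per the commit summary, the benefit of the class-based loop is that test variants can override individual methods instead of copying the whole training loop. A hypothetical example against the sketch above (the real hook names are whatever torchft/test/diloco_trainer.py defines):

# Hypothetical subclass of the DiLoCoTrainerSketch shown earlier, illustrating
# the override-instead-of-duplicate pattern described in the commit summary.
class LoggingDiLoCoTrainerSketch(DiLoCoTrainerSketch):
    def train_loop(self) -> Dict[int, Any]:
        # A variant test changes only this method (or, in the real class, a
        # smaller per-step hook) and reuses all of the setup in __init__.
        print(f"rank {self.rank}: starting variant train loop")
        return super().train_loop()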