
Commit dc1037e

d4l3k and dl541 authored
manager: Add option to skip initial sync (#159)
* add option to skip initial sync in Manager

  Summary: We currently always heal on step 0 to avoid synchronization issues. This adds an option to skip that sync for users who set the PyTorch seed so that all ranks are initialized with the same values. The change adds an `init_sync` boolean flag that is passed from the manager client in Python to the manager service in Rust; the manager service skips the step-0 sync when `init_sync` is false.

* Fixed some rust files and tests

* Fix init_sync logic

* Add tests for manager.rs

* Added skip init_sync tests to python client

* manager: final changes for init_sync

Co-authored-by: Dave Lei <[email protected]>
1 parent: 9cc565c

File tree: 7 files changed, +170 −14 lines
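
Before the per-file diffs, a small illustration of the precondition this option relies on (not part of the commit, just a hedged sketch): if every replica seeds the PyTorch RNG before building the model, all replicas start from identical weights, so the step-0 synchronization that `init_sync` normally forces becomes redundant.

```python
import torch
import torch.nn as nn


def build_replica_model(seed: int = 42) -> nn.Module:
    # Seeding before parameter creation makes every replica construct
    # identical weights, which is what makes init_sync=False safe.
    torch.manual_seed(seed)
    return nn.Linear(3, 4)


if __name__ == "__main__":
    a = build_replica_model()
    b = build_replica_model()
    # Two "replicas" end up bitwise identical without any step-0 sync.
    for pa, pb in zip(a.parameters(), b.parameters()):
        assert torch.equal(pa, pb)
```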

proto/torchft.proto (+1)

```diff
@@ -76,6 +76,7 @@ message ManagerQuorumRequest {
   int64 step = 2;
   string checkpoint_metadata = 3;
   bool shrink_only = 4;
+  bool init_sync = 5;
 }
 
 message ManagerQuorumResponse {
```

src/lib.rs (+2)

```diff
@@ -175,6 +175,7 @@ impl ManagerClient {
         step: i64,
         checkpoint_metadata: String,
         shrink_only: bool,
+        init_sync: bool,
         timeout: Duration,
     ) -> Result<QuorumResult, StatusError> {
         py.allow_threads(move || {
@@ -183,6 +184,7 @@ impl ManagerClient {
                 step: step,
                 checkpoint_metadata: checkpoint_metadata,
                 shrink_only: shrink_only,
+                init_sync: init_sync,
             });
 
             // This timeout is processed on the server side so we also enable
```

src/manager.rs (+63 −11)

```diff
@@ -286,7 +286,7 @@ impl ManagerService for Arc<Manager> {
 
         info_with_replica!(self.replica_id, "Finished quorum for rank {}", rank);
 
-        let reply = compute_quorum_results(&self.replica_id, rank, &quorum)?;
+        let reply = compute_quorum_results(&self.replica_id, rank, &quorum, req.init_sync)?;
 
         Ok(Response::new(reply))
     }
@@ -382,6 +382,7 @@ fn compute_quorum_results(
     replica_id: &str,
     rank: i64,
     quorum: &Quorum,
+    init_sync: bool,
 ) -> Result<ManagerQuorumResponse, Status> {
     let mut participants = quorum.participants.clone();
     participants.sort_by(|a, b| a.replica_id.cmp(&b.replica_id));
@@ -424,20 +425,23 @@
 
     // Compute recovery assignments
 
-    // Nodes are recovering if:
-    // 1. not at the max step
+    let force_recover = init_sync && max_step == 0;
+
+    // Nodes are recovering if
+    // 1. not at the max step (init_sync)
     // 2. max_step == 0 and not the primary replica
     let all_recover_dst_ranks: Vec<usize> = participants
         .iter()
         .enumerate()
         .filter_map(|(i, p)| {
-            if p.step != max_step || max_step == 0 && primary.replica_id != p.replica_id {
+            if p.step != max_step || force_recover && primary.replica_id != p.replica_id {
                 Some(i)
             } else {
                 None
             }
         })
         .collect();
+
     let all_recover_dst_ranks_set = all_recover_dst_ranks.iter().collect::<HashSet<_>>();
     let up_to_date_ranks: Vec<usize> = participants
         .iter()
@@ -605,6 +609,7 @@ mod tests {
             step: 123,
             checkpoint_metadata: "addr".to_string(),
             shrink_only: false,
+            init_sync: true,
         });
         request.set_timeout(Duration::from_secs(10));
         let resp = client.quorum(request).await?.into_inner();
@@ -664,6 +669,7 @@ mod tests {
             step: 0,
             checkpoint_metadata: "addr".to_string(),
             shrink_only: false,
+            init_sync: true,
         });
         request.set_timeout(Duration::from_secs(10));
 
@@ -771,21 +777,21 @@ mod tests {
 
         // rank 0
 
-        let results = compute_quorum_results("replica_0", 0, &quorum)?;
+        let results = compute_quorum_results("replica_0", 0, &quorum, true)?;
         assert!(!results.heal);
         assert_eq!(results.replica_rank, 0);
         assert_eq!(results.recover_src_rank, None);
         assert_eq!(results.recover_dst_ranks, vec![1]);
 
-        let results = compute_quorum_results("replica_1", 0, &quorum)?;
+        let results = compute_quorum_results("replica_1", 0, &quorum, true)?;
         assert!(results.heal);
         assert_eq!(results.replica_rank, 1);
         assert_eq!(results.recover_src_rank, Some(0));
         assert_eq!(results.recover_dst_ranks, Vec::<i64>::new());
 
         // rank 1 assignments should be offset from rank 0 above and the primary
 
-        let results = compute_quorum_results("replica_1", 1, &quorum)?;
+        let results = compute_quorum_results("replica_1", 1, &quorum, true)?;
         assert!(!results.heal);
         assert_eq!(results.replica_rank, 1);
         assert_eq!(results.recover_src_rank, None);
@@ -850,34 +856,80 @@ mod tests {
 
         // rank 0
 
-        let results = compute_quorum_results("replica_0", 0, &quorum)?;
+        let results = compute_quorum_results("replica_0", 0, &quorum, true)?;
         assert!(results.heal);
         assert_eq!(results.recover_src_manager_address, "addr_1".to_string());
         assert_eq!(results.replica_rank, 0);
         assert_eq!(results.recover_src_rank, Some(1));
         assert!(results.recover_dst_ranks.is_empty());
 
-        let results = compute_quorum_results("replica_1", 0, &quorum)?;
+        let results = compute_quorum_results("replica_1", 0, &quorum, true)?;
         assert!(!results.heal);
         assert_eq!(results.recover_src_manager_address, "".to_string());
         assert_eq!(results.replica_rank, 1);
         assert_eq!(results.recover_src_rank, None);
         assert_eq!(results.recover_dst_ranks, vec![0, 4]);
 
-        let results = compute_quorum_results("replica_3", 0, &quorum)?;
+        let results = compute_quorum_results("replica_3", 0, &quorum, true)?;
         assert!(!results.heal);
         assert_eq!(results.replica_rank, 3);
         assert_eq!(results.recover_src_rank, None);
         assert_eq!(results.recover_dst_ranks, vec![2]);
 
         // rank 1 assignments should be offset from rank 0 above
 
-        let results = compute_quorum_results("replica_1", 1, &quorum)?;
+        let results = compute_quorum_results("replica_1", 1, &quorum, true)?;
         assert!(!results.heal);
         assert_eq!(results.replica_rank, 1);
         assert_eq!(results.recover_src_rank, None);
         assert_eq!(results.recover_dst_ranks, vec![2]);
 
         Ok(())
     }
+
+    #[tokio::test]
+    async fn test_compute_quorum_results_skip_init_sync() -> Result<()> {
+        let mut quorum = Quorum {
+            quorum_id: 1,
+            participants: vec![
+                QuorumMember {
+                    replica_id: "replica_0".to_string(),
+                    address: "addr_0".to_string(),
+                    store_address: "store_addr_0".to_string(),
+                    step: 0,
+                    world_size: 1,
+                    shrink_only: false,
+                    data: String::new(),
+                },
+                QuorumMember {
+                    replica_id: "replica_1".to_string(),
+                    address: "addr_1".to_string(),
+                    store_address: "store_addr_1".to_string(),
+                    step: 0,
+                    world_size: 1,
+                    shrink_only: false,
+                    data: String::new(),
+                },
+            ],
+            created: None,
+        };
+
+        // baseline w/ init_sync=true
+        let results = compute_quorum_results("replica_0", 0, &quorum, true)?;
+        assert!(!results.heal);
+
+        let results = compute_quorum_results("replica_1", 0, &quorum, true)?;
+        assert!(results.heal);
+
+        // init_sync=false
+        let results = compute_quorum_results("replica_1", 0, &quorum, false)?;
+        assert!(!results.heal);
+
+        // init_sync=false, step=1
+        quorum.participants[0].step = 1;
+        let results = compute_quorum_results("replica_1", 0, &quorum, false)?;
+        assert!(results.heal);
+
+        Ok(())
+    }
 }
```
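
To make the new `force_recover` condition easier to follow, here is a hedged Python sketch of the predicate used in `compute_quorum_results` above (the `Participant` type and the `primary_id` argument are simplifications, not the actual Rust types): a node is asked to recover if it is behind the max step, or, when `init_sync` is enabled and no step has been taken yet, if it is not the primary replica.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Participant:
    replica_id: str
    step: int


def recover_dst_indices(
    participants: List[Participant], primary_id: str, init_sync: bool
) -> List[int]:
    # Mirrors the filter in compute_quorum_results: recover if behind the
    # max step, or (only when init_sync is set and max_step == 0) if the
    # participant is not the primary replica.
    max_step = max(p.step for p in participants)
    force_recover = init_sync and max_step == 0
    return [
        i
        for i, p in enumerate(participants)
        if p.step != max_step or (force_recover and p.replica_id != primary_id)
    ]


members = [Participant("replica_0", 0), Participant("replica_1", 0)]
assert recover_dst_indices(members, "replica_0", init_sync=True) == [1]
assert recover_dst_indices(members, "replica_0", init_sync=False) == []

# Once someone has taken a step, lagging replicas recover regardless of init_sync.
members[0].step = 1
assert recover_dst_indices(members, "replica_0", init_sync=False) == [1]
```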

torchft/_torchft.pyi (+1)

```diff
@@ -11,6 +11,7 @@ class ManagerClient:
         checkpoint_metadata: str,
         shrink_only: bool,
         timeout: timedelta,
+        init_sync: bool = True,
     ) -> QuorumResult: ...
     def _checkpoint_metadata(self, rank: int, timeout: timedelta) -> str: ...
     def should_commit(
```

torchft/manager.py (+6)

```diff
@@ -106,6 +106,7 @@ def __init__(
         hostname: str = socket.gethostname(),
         heartbeat_interval: timedelta = timedelta(milliseconds=100),
         checkpoint_transport: Optional[CheckpointTransport[Dict[str, T]]] = None,
+        init_sync: bool = True,
     ) -> None:
         """
         Args:
@@ -143,6 +144,9 @@ def __init__(
             hostname: if rank==0, the hostname to advertise to the lighthouse server
             checkpoint_transport: the checkpoint transport to use for
                 transfering checkpoints to recovering replicas, defaults to HTTPTransport
+            init_sync: whether to synchronize the model weights on step 0. If
+                all of the model weights are initialized identically via
+                ``torch.set_seed`` you should set this to False.
         """
         self._load_state_dict = load_state_dict
         self._user_state_dict = state_dict
@@ -152,6 +156,7 @@ def __init__(
         self._quorum_timeout = quorum_timeout
         self._connect_timeout = connect_timeout
         self._world_size_mode = world_size_mode
+        self._init_sync = init_sync
 
         store_addr = store_addr or os.environ["MASTER_ADDR"]
         store_port = store_port or int(os.environ["MASTER_PORT"])
@@ -455,6 +460,7 @@ def _async_quorum(
             checkpoint_metadata=self._checkpoint_transport.metadata(),
             shrink_only=shrink_only,
             timeout=quorum_timeout,
+            init_sync=self._init_sync,
         )
 
         quorum_id = quorum.quorum_id
```
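
For context on how the new flag is meant to be used from the Python side, here is a hypothetical wiring sketch. Only `init_sync` comes from this diff; the other constructor arguments (`pg`, `min_replica_size`, `load_state_dict`, `state_dict`) follow torchft's usual `Manager` usage and may differ between versions, and actually running this requires a lighthouse plus `MASTER_ADDR`/`MASTER_PORT` in the environment.

```python
import torch
import torch.nn as nn

from torchft import Manager, ProcessGroupGloo

# Deterministic init on every replica is the user's responsibility when
# passing init_sync=False.
torch.manual_seed(42)
model = nn.Linear(3, 4)

manager = Manager(
    pg=ProcessGroupGloo(),
    min_replica_size=2,
    load_state_dict=model.load_state_dict,
    state_dict=model.state_dict,
    # Skip the forced step-0 weight synchronization; safe only because the
    # weights above were seeded identically on every replica.
    init_sync=False,
)
```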

torchft/manager_integ_test.py (+66 −3)

```diff
@@ -25,6 +25,8 @@
 
 logger: logging.Logger = logging.getLogger(__name__)
 
+INIT_LOCK: threading.Lock = threading.Lock()
+
 
 class MyModel(nn.Module):
     def __init__(self, in_dim: int = 3, out_dim: int = 4) -> None:
@@ -191,7 +193,13 @@ def state_dict() -> Dict[str, Dict[str, object]]:
         )
         stack.callback(lambda: manager.shutdown(wait=False))
 
-        m: nn.Module = DistributedDataParallel(manager, MyModel())
+        with INIT_LOCK:
+            # We need to lock during init for testing init_sync=False as all
+            # threads share the same RNG
+            torch.manual_seed(42)
+            m: nn.Module = MyModel()
+
+        m: nn.Module = DistributedDataParallel(manager, m)
         optimizer: optim.Optimizer = OptimizerWrapper(
             manager, optim.Adam(m.parameters())
         )
@@ -270,7 +278,11 @@ def test_ddp_healthy(self) -> None:
             ),
         ]
     )
-    def test_ddp_recovery(self, name: str, use_async_quorum: bool) -> None:
+    def test_ddp_recovery(
+        self,
+        name: str,
+        use_async_quorum: bool,
+    ) -> None:
         lighthouse = LighthouseServer(
             bind="[::]:0",
             min_replicas=2,
@@ -302,7 +314,11 @@ def test_ddp_recovery(self, name: str, use_async_quorum: bool) -> None:
             state_dicts = []
 
             for fut in as_completed(futures):
-                state_dicts.append(fut.result())
+                try:
+                    state_dicts.append(fut.result())
+                except Exception as e:
+                    print(e)
+                    raise
 
             lighthouse.shutdown()
 
@@ -311,6 +327,53 @@ def test_ddp_recovery(self, name: str, use_async_quorum: bool) -> None:
 
         self.assertEqual(failure_injectors[1].count, 1)
 
+    def test_ddp_skip_init_sync(
+        self,
+    ) -> None:
+        lighthouse = LighthouseServer(
+            bind="[::]:0",
+            min_replicas=2,
+        )
+        num_replicas = 2
+        futures = []
+
+        # no failures
+        failure_injectors = [
+            FailureInjector(),
+            FailureInjector(),
+        ]
+
+        with ThreadPoolExecutor(max_workers=num_replicas) as executor:
+            for replica_id, failure_injector in zip(
+                range(num_replicas), failure_injectors
+            ):
+                runner = Runner(
+                    replica_id=replica_id,
+                    num_replicas=num_replicas,
+                    lighthouse_address=lighthouse.address(),
+                    failure_injector=failure_injector,
+                    manager_args={
+                        "use_async_quorum": False,
+                        "init_sync": False,
+                    },
+                    train_loop=ddp_train_loop,
+                )
+                futures.append(executor.submit(runner.run_replica))
+
+            state_dicts = []
+
+            for fut in as_completed(futures):
+                try:
+                    state_dicts.append(fut.result())
+                except Exception as e:
+                    print(e)
+                    raise
+
+            lighthouse.shutdown()
+
+            for state_dict in state_dicts:
+                torch.testing.assert_close(state_dict, state_dicts[0])
+
     def test_ddp_recovery_multi_rank(self) -> None:
         lighthouse = LighthouseServer(
             bind="[::]:0",
```
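
The `INIT_LOCK` added above exists because the integration test runs replicas as threads in a single process, and `torch.manual_seed` mutates process-global RNG state. A standalone sketch, independent of the test harness, of why seeding and model construction must be atomic:

```python
import threading

import torch
import torch.nn as nn

INIT_LOCK = threading.Lock()


def build_model() -> nn.Module:
    # Without the lock, another thread could consume random numbers between
    # manual_seed() and parameter creation, breaking the identical-init
    # premise that init_sync=False depends on.
    with INIT_LOCK:
        torch.manual_seed(42)
        return nn.Linear(3, 4)


models: list = []
threads = [
    threading.Thread(target=lambda: models.append(build_model())) for _ in range(2)
]
for t in threads:
    t.start()
for t in threads:
    t.join()
for p0, p1 in zip(models[0].parameters(), models[1].parameters()):
    assert torch.equal(p0, p1)
```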

torchft/manager_test.py (+31)

```diff
@@ -40,6 +40,7 @@ def _create_manager(
         min_replica_size: int = 2,
         world_size_mode: WorldSizeMode = WorldSizeMode.DYNAMIC,
         timeout: timedelta = timedelta(seconds=10),
+        init_sync: bool = True,
     ) -> Manager:
         pg = create_autospec(ProcessGroup)
         pg.errored.return_value = None
@@ -67,6 +68,7 @@ def _create_manager(
             use_async_quorum=use_async_quorum,
             world_size_mode=world_size_mode,
             timeout=timeout,
+            init_sync=init_sync,
         )
         self.manager = manager
         return manager
@@ -614,3 +616,32 @@ def test_quorum_happy_timeouts(self, client_mock: MagicMock) -> None:
             client_mock().should_commit.call_args.kwargs["timeout"],
             timedelta(seconds=23),
         )
+
+    @patch("torchft.manager.ManagerClient", autospec=True)
+    def test_quorum_skip_init(self, client_mock: MagicMock) -> None:
+        manager = self._create_manager(
+            use_async_quorum=False,
+            init_sync=False,
+        )
+
+        self.assertFalse(manager._init_sync)
+
+        quorum = QuorumResult()
+        quorum.quorum_id = 123
+        quorum.replica_rank = 1
+        quorum.replica_world_size = 2
+        quorum.recover_src_manager_address = "manager address"
+        quorum.store_address = f"localhost:{self.store.port}"
+        quorum.max_step = 1
+        quorum.max_rank = 1
+        quorum.max_world_size = 2
+        quorum.heal = False
+
+        client_mock()._quorum.return_value = quorum
+
+        manager.start_quorum()
+        self.assertEqual(client_mock()._quorum.call_args.kwargs["init_sync"], False)
+
+        manager._init_sync = True
+        manager.start_quorum()
+        self.assertEqual(client_mock()._quorum.call_args.kwargs["init_sync"], True)
```
