From 9ef0785e56f3299b0a65c4034cbe18cacbc4b863 Mon Sep 17 00:00:00 2001 From: Dave Lei Date: Sat, 8 Mar 2025 17:03:57 -0800 Subject: [PATCH 1/6] add option to skip initial sync in Manager Summary: We currently always heal on step 0 to avoid synchronization issues. We want an option to skip this sync for users who set the PyTorch seed so that all ranks are initialized with the same values. This diff adds an init_sync boolean flag that is passed from the manager client in Python to the manager service in Rust. The manager service then skips the initial sync when init_sync is false. Test Plan: Reviewers: Subscribers: Tasks: Tags: --- proto/torchft.proto | 1 + src/manager.rs | 5 ++++- torchft/_torchft.pyi | 1 + torchft/manager.py | 3 ++- 4 files changed, 8 insertions(+), 2 deletions(-) diff --git a/proto/torchft.proto b/proto/torchft.proto index c4d0a81b..7ffcaec6 100644 --- a/proto/torchft.proto +++ b/proto/torchft.proto @@ -76,6 +76,7 @@ message ManagerQuorumRequest { int64 step = 2; string checkpoint_metadata = 3; bool shrink_only = 4; + optional bool init_sync = 5; } message ManagerQuorumResponse { diff --git a/src/manager.rs b/src/manager.rs index 08d0cc28..bc9654f9 100644 --- a/src/manager.rs +++ b/src/manager.rs @@ -431,7 +431,10 @@ fn compute_quorum_results( .iter() .enumerate() .filter_map(|(i, p)| { - if p.step != max_step || max_step == 0 && primary.replica_id != p.replica_id { + if init_sync + || p.step != max_step + || max_step == 0 && primary.replica_id != p.replica_id + { Some(i) } else { None diff --git a/torchft/_torchft.pyi b/torchft/_torchft.pyi index 49bdcddd..1a99913e 100644 --- a/torchft/_torchft.pyi +++ b/torchft/_torchft.pyi @@ -11,6 +11,7 @@ class ManagerClient: checkpoint_metadata: str, shrink_only: bool, timeout: timedelta, + init_sync: Optional[bool] = False, ) -> QuorumResult: ... def _checkpoint_metadata(self, rank: int, timeout: timedelta) -> str: ...
def should_commit( diff --git a/torchft/manager.py b/torchft/manager.py index 0697bd4d..426fb35f 100644 --- a/torchft/manager.py +++ b/torchft/manager.py @@ -34,7 +34,7 @@ from contextlib import nullcontext from datetime import timedelta from enum import Enum -from typing import TYPE_CHECKING, Callable, Dict, List, Optional, TypeVar, cast +from typing import Callable, cast, Dict, List, Optional, TYPE_CHECKING, TypeVar import torch from torch.distributed import ReduceOp, TCPStore @@ -455,6 +455,7 @@ def _async_quorum( checkpoint_metadata=self._checkpoint_transport.metadata(), shrink_only=shrink_only, timeout=quorum_timeout, + init_sync=self.init_sync, ) quorum_id = quorum.quorum_id From e768c0ab0ae7afe8ec45de1903ea01e62c8d4f18 Mon Sep 17 00:00:00 2001 From: Dave Lei Date: Sat, 8 Mar 2025 22:14:10 -0700 Subject: [PATCH 2/6] Fixed some rust files and tests Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- src/lib.rs | 1 + src/manager.rs | 24 ++++++++++++++++-------- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index a29e00d5..00254e29 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -183,6 +183,7 @@ impl ManagerClient { step: step, checkpoint_metadata: checkpoint_metadata, shrink_only: shrink_only, + init_sync: Some(false), }); // This timeout is processed on the server side so we also enable diff --git a/src/manager.rs b/src/manager.rs index bc9654f9..48db5a99 100644 --- a/src/manager.rs +++ b/src/manager.rs @@ -286,7 +286,12 @@ impl ManagerService for Arc { info_with_replica!(self.replica_id, "Finished quorum for rank {}", rank); - let reply = compute_quorum_results(&self.replica_id, rank, &quorum)?; + let reply = compute_quorum_results( + &self.replica_id, + rank, + &quorum, + req.init_sync.unwrap_or_default(), + )?; Ok(Response::new(reply)) } @@ -382,6 +387,7 @@ fn compute_quorum_results( replica_id: &str, rank: i64, quorum: &Quorum, + init_sync: bool, ) -> Result { let mut participants = quorum.participants.clone(); participants.sort_by(|a, b| a.replica_id.cmp(&b.replica_id)); @@ -608,6 +614,7 @@ mod tests { step: 123, checkpoint_metadata: "addr".to_string(), shrink_only: false, + init_sync: Some(false), }); request.set_timeout(Duration::from_secs(10)); let resp = client.quorum(request).await?.into_inner(); @@ -667,6 +674,7 @@ mod tests { step: 0, checkpoint_metadata: "addr".to_string(), shrink_only: false, + init_sync: Some(false), }); request.set_timeout(Duration::from_secs(10)); @@ -774,13 +782,13 @@ mod tests { // rank 0 - let results = compute_quorum_results("replica_0", 0, &quorum)?; + let results = compute_quorum_results("replica_0", 0, &quorum, false)?; assert!(!results.heal); assert_eq!(results.replica_rank, 0); assert_eq!(results.recover_src_rank, None); assert_eq!(results.recover_dst_ranks, vec![1]); - let results = compute_quorum_results("replica_1", 0, &quorum)?; + let results = compute_quorum_results("replica_1", 0, &quorum, false)?; assert!(results.heal); assert_eq!(results.replica_rank, 1); assert_eq!(results.recover_src_rank, Some(0)); @@ -788,7 +796,7 @@ mod tests { // rank 1 assignments should be offset from rank 0 above and the primary - let results = compute_quorum_results("replica_1", 1, &quorum)?; + let results = compute_quorum_results("replica_1", 1, &quorum, false)?; assert!(!results.heal); assert_eq!(results.replica_rank, 1); assert_eq!(results.recover_src_rank, None); @@ -853,21 +861,21 @@ mod tests { // rank 0 - let results = compute_quorum_results("replica_0", 0, &quorum)?; + let results = 
compute_quorum_results("replica_0", 0, &quorum, false)?; assert!(results.heal); assert_eq!(results.recover_src_manager_address, "addr_1".to_string()); assert_eq!(results.replica_rank, 0); assert_eq!(results.recover_src_rank, Some(1)); assert!(results.recover_dst_ranks.is_empty()); - let results = compute_quorum_results("replica_1", 0, &quorum)?; + let results = compute_quorum_results("replica_1", 0, &quorum, false)?; assert!(!results.heal); assert_eq!(results.recover_src_manager_address, "".to_string()); assert_eq!(results.replica_rank, 1); assert_eq!(results.recover_src_rank, None); assert_eq!(results.recover_dst_ranks, vec![0, 4]); - let results = compute_quorum_results("replica_3", 0, &quorum)?; + let results = compute_quorum_results("replica_3", 0, &quorum, false)?; assert!(!results.heal); assert_eq!(results.replica_rank, 3); assert_eq!(results.recover_src_rank, None); @@ -875,7 +883,7 @@ mod tests { // rank 1 assignments should be offset from rank 0 above - let results = compute_quorum_results("replica_1", 1, &quorum)?; + let results = compute_quorum_results("replica_1", 1, &quorum, false)?; assert!(!results.heal); assert_eq!(results.replica_rank, 1); assert_eq!(results.recover_src_rank, None); From f7794a2749a5f0882310c079a66289986ad1321d Mon Sep 17 00:00:00 2001 From: Dave Lei Date: Mon, 10 Mar 2025 07:42:45 -0700 Subject: [PATCH 3/6] Fix init_sync logic Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- proto/torchft.proto | 2 +- src/lib.rs | 2 +- src/manager.rs | 61 +++++++++++++++++++++----------------------- torchft/_torchft.pyi | 2 +- 4 files changed, 32 insertions(+), 35 deletions(-) diff --git a/proto/torchft.proto b/proto/torchft.proto index 7ffcaec6..15a96d0f 100644 --- a/proto/torchft.proto +++ b/proto/torchft.proto @@ -76,7 +76,7 @@ message ManagerQuorumRequest { int64 step = 2; string checkpoint_metadata = 3; bool shrink_only = 4; - optional bool init_sync = 5; + bool init_sync = 5; } message ManagerQuorumResponse { diff --git a/src/lib.rs b/src/lib.rs index 00254e29..68c81bfe 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -183,7 +183,7 @@ impl ManagerClient { step: step, checkpoint_metadata: checkpoint_metadata, shrink_only: shrink_only, - init_sync: Some(false), + init_sync: true, }); // This timeout is processed on the server side so we also enable diff --git a/src/manager.rs b/src/manager.rs index 48db5a99..7b317940 100644 --- a/src/manager.rs +++ b/src/manager.rs @@ -286,12 +286,7 @@ impl ManagerService for Arc { info_with_replica!(self.replica_id, "Finished quorum for rank {}", rank); - let reply = compute_quorum_results( - &self.replica_id, - rank, - &quorum, - req.init_sync.unwrap_or_default(), - )?; + let reply = compute_quorum_results(&self.replica_id, rank, &quorum, req.init_sync)?; Ok(Response::new(reply)) } @@ -430,23 +425,25 @@ fn compute_quorum_results( // Compute recovery assignments - // Nodes are recovering if: - // 1. not at the max step - // 2. max_step == 0 and not the primary replica - let all_recover_dst_ranks: Vec = participants - .iter() - .enumerate() - .filter_map(|(i, p)| { - if init_sync - || p.step != max_step - || max_step == 0 && primary.replica_id != p.replica_id - { - Some(i) - } else { - None - } - }) - .collect(); + let all_recover_dst_ranks = if init_sync { + // Nodes are recovering if + // 1. not at the max step + // 2. 
max_step == 0 and not the primary replica + participants + .iter() + .enumerate() + .filter_map(|(i, p)| { + if p.step != max_step || max_step == 0 && primary.replica_id != p.replica_id { + Some(i) + } else { + None + } + }) + .collect() + } else { + Vec::::new() + }; + let all_recover_dst_ranks_set = all_recover_dst_ranks.iter().collect::>(); let up_to_date_ranks: Vec = participants .iter() @@ -614,7 +611,7 @@ mod tests { step: 123, checkpoint_metadata: "addr".to_string(), shrink_only: false, - init_sync: Some(false), + init_sync: true, }); request.set_timeout(Duration::from_secs(10)); let resp = client.quorum(request).await?.into_inner(); @@ -674,7 +671,7 @@ mod tests { step: 0, checkpoint_metadata: "addr".to_string(), shrink_only: false, - init_sync: Some(false), + init_sync: true, }); request.set_timeout(Duration::from_secs(10)); @@ -782,13 +779,13 @@ mod tests { // rank 0 - let results = compute_quorum_results("replica_0", 0, &quorum, false)?; + let results = compute_quorum_results("replica_0", 0, &quorum, true)?; assert!(!results.heal); assert_eq!(results.replica_rank, 0); assert_eq!(results.recover_src_rank, None); assert_eq!(results.recover_dst_ranks, vec![1]); - let results = compute_quorum_results("replica_1", 0, &quorum, false)?; + let results = compute_quorum_results("replica_1", 0, &quorum, true)?; assert!(results.heal); assert_eq!(results.replica_rank, 1); assert_eq!(results.recover_src_rank, Some(0)); @@ -796,7 +793,7 @@ mod tests { // rank 1 assignments should be offset from rank 0 above and the primary - let results = compute_quorum_results("replica_1", 1, &quorum, false)?; + let results = compute_quorum_results("replica_1", 1, &quorum, true)?; assert!(!results.heal); assert_eq!(results.replica_rank, 1); assert_eq!(results.recover_src_rank, None); @@ -861,21 +858,21 @@ mod tests { // rank 0 - let results = compute_quorum_results("replica_0", 0, &quorum, false)?; + let results = compute_quorum_results("replica_0", 0, &quorum, true)?; assert!(results.heal); assert_eq!(results.recover_src_manager_address, "addr_1".to_string()); assert_eq!(results.replica_rank, 0); assert_eq!(results.recover_src_rank, Some(1)); assert!(results.recover_dst_ranks.is_empty()); - let results = compute_quorum_results("replica_1", 0, &quorum, false)?; + let results = compute_quorum_results("replica_1", 0, &quorum, true)?; assert!(!results.heal); assert_eq!(results.recover_src_manager_address, "".to_string()); assert_eq!(results.replica_rank, 1); assert_eq!(results.recover_src_rank, None); assert_eq!(results.recover_dst_ranks, vec![0, 4]); - let results = compute_quorum_results("replica_3", 0, &quorum, false)?; + let results = compute_quorum_results("replica_3", 0, &quorum, true)?; assert!(!results.heal); assert_eq!(results.replica_rank, 3); assert_eq!(results.recover_src_rank, None); @@ -883,7 +880,7 @@ mod tests { // rank 1 assignments should be offset from rank 0 above - let results = compute_quorum_results("replica_1", 1, &quorum, false)?; + let results = compute_quorum_results("replica_1", 1, &quorum, true)?; assert!(!results.heal); assert_eq!(results.replica_rank, 1); assert_eq!(results.recover_src_rank, None); diff --git a/torchft/_torchft.pyi b/torchft/_torchft.pyi index 1a99913e..fdbd1fa3 100644 --- a/torchft/_torchft.pyi +++ b/torchft/_torchft.pyi @@ -11,7 +11,7 @@ class ManagerClient: checkpoint_metadata: str, shrink_only: bool, timeout: timedelta, - init_sync: Optional[bool] = False, + init_sync: bool = True, ) -> QuorumResult: ... 
def _checkpoint_metadata(self, rank: int, timeout: timedelta) -> str: ... def should_commit( From ab05ae7372f97235d7b6c240f887f35c8bf48b3a Mon Sep 17 00:00:00 2001 From: Dave Lei Date: Tue, 11 Mar 2025 09:23:21 -0700 Subject: [PATCH 4/6] Add tests for manager.rs Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- src/manager.rs | 82 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) diff --git a/src/manager.rs b/src/manager.rs index 7b317940..519e7ca6 100644 --- a/src/manager.rs +++ b/src/manager.rs @@ -888,4 +888,86 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_compute_quorum_results_skip_init_sync() -> Result<()> { + let quorum = Quorum { + quorum_id: 1, + participants: vec![ + QuorumMember { + replica_id: "replica_0".to_string(), + address: "addr_0".to_string(), + store_address: "store_addr_0".to_string(), + step: 0, + world_size: 1, + shrink_only: false, + }, + QuorumMember { + replica_id: "replica_1".to_string(), + address: "addr_1".to_string(), + store_address: "store_addr_1".to_string(), + step: 1, + world_size: 1, + shrink_only: false, + }, + QuorumMember { + replica_id: "replica_2".to_string(), + address: "addr_2".to_string(), + store_address: "store_addr_2".to_string(), + step: 0, + world_size: 1, + shrink_only: false, + }, + QuorumMember { + replica_id: "replica_3".to_string(), + address: "addr_3".to_string(), + store_address: "store_addr_3".to_string(), + step: 1, + world_size: 1, + shrink_only: false, + }, + QuorumMember { + replica_id: "replica_4".to_string(), + address: "addr_4".to_string(), + store_address: "store_addr_4".to_string(), + step: 0, + world_size: 1, + shrink_only: false, + }, + ], + created: None, + }; + + // rank 0 + + let results = compute_quorum_results("replica_0", 0, &quorum, false)?; + assert!(!results.heal); + assert_eq!(results.recover_src_manager_address, "".to_string()); + assert_eq!(results.replica_rank, 0); + assert_eq!(results.recover_src_rank, None); + assert!(results.recover_dst_ranks.is_empty()); + + let results = compute_quorum_results("replica_1", 0, &quorum, false)?; + assert!(!results.heal); + assert_eq!(results.recover_src_manager_address, "".to_string()); + assert_eq!(results.replica_rank, 1); + assert_eq!(results.recover_src_rank, None); + assert!(results.recover_dst_ranks.is_empty()); + + let results = compute_quorum_results("replica_3", 0, &quorum, false)?; + assert!(!results.heal); + assert_eq!(results.recover_src_manager_address, "".to_string()); + assert_eq!(results.replica_rank, 3); + assert_eq!(results.recover_src_rank, None); + assert!(results.recover_dst_ranks.is_empty()); + + let results = compute_quorum_results("replica_1", 1, &quorum, false)?; + assert!(!results.heal); + assert_eq!(results.recover_src_manager_address, "".to_string()); + assert_eq!(results.replica_rank, 1); + assert_eq!(results.recover_src_rank, None); + assert!(results.recover_dst_ranks.is_empty()); + + Ok(()) + } } From c5940880a589e4b6bad22b6d2c278cf1a59d9eb5 Mon Sep 17 00:00:00 2001 From: Dave Lei Date: Wed, 12 Mar 2025 11:48:37 -0700 Subject: [PATCH 5/6] Added skip init_sync tests to python client Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- torchft/manager_test.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/torchft/manager_test.py b/torchft/manager_test.py index fb134967..954b88aa 100644 --- a/torchft/manager_test.py +++ b/torchft/manager_test.py @@ -614,3 +614,35 @@ def test_quorum_happy_timeouts(self, client_mock: MagicMock) -> None: 
client_mock().should_commit.call_args.kwargs["timeout"], timedelta(seconds=23), ) + + @patch("torchft.manager.ManagerClient", autospec=True) + def test_quorum_skip_init(self, client_mock: MagicMock) -> None: + manager = self._create_manager(use_async_quorum=False) + + quorum = QuorumResult() + quorum.quorum_id = 123 + quorum.replica_rank = 1 + quorum.replica_world_size = 2 + quorum.recover_src_manager_address = "manager address" + quorum.store_address = f"localhost:{self.store.port}" + quorum.max_step = 1 + quorum.max_rank = 1 + quorum.max_world_size = 2 + quorum.heal = False + + client_mock()._quorum.return_value = quorum + + manager.start_quorum() + self.assertEqual( + client_mock()._quorum.call_args.kwargs["init_sync"], True + ) + + manager.start_quorum(init_sync=True) + self.assertEqual( + client_mock()._quorum.call_args.kwargs["init_sync"], True + ) + + manager.start_quorum(init_sync=False) + self.assertEqual( + client_mock()._quorum.call_args.kwargs["init_sync"], False + ) From e06c6f278263604213246b255d52ea1b9e2deba0 Mon Sep 17 00:00:00 2001 From: Tristan Rice Date: Tue, 8 Apr 2025 14:31:08 -0700 Subject: [PATCH 6/6] manager: final changes for init_sync --- src/lib.rs | 3 +- src/manager.rs | 96 +++++++++++------------------------ torchft/manager.py | 9 +++- torchft/manager_integ_test.py | 69 +++++++++++++++++++++++-- torchft/manager_test.py | 25 +++++---- 5 files changed, 116 insertions(+), 86 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 68c81bfe..5ef1bcfa 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -175,6 +175,7 @@ impl ManagerClient { step: i64, checkpoint_metadata: String, shrink_only: bool, + init_sync: bool, timeout: Duration, ) -> Result { py.allow_threads(move || { @@ -183,7 +184,7 @@ impl ManagerClient { step: step, checkpoint_metadata: checkpoint_metadata, shrink_only: shrink_only, - init_sync: true, + init_sync: init_sync, }); // This timeout is processed on the server side so we also enable diff --git a/src/manager.rs b/src/manager.rs index 519e7ca6..358ff7a1 100644 --- a/src/manager.rs +++ b/src/manager.rs @@ -425,24 +425,22 @@ fn compute_quorum_results( // Compute recovery assignments - let all_recover_dst_ranks = if init_sync { - // Nodes are recovering if - // 1. not at the max step - // 2. max_step == 0 and not the primary replica - participants - .iter() - .enumerate() - .filter_map(|(i, p)| { - if p.step != max_step || max_step == 0 && primary.replica_id != p.replica_id { - Some(i) - } else { - None - } - }) - .collect() - } else { - Vec::::new() - }; + let force_recover = init_sync && max_step == 0; + + // Nodes are recovering if + // 1. not at the max step (init_sync) + // 2. 
max_step == 0 and not the primary replica + let all_recover_dst_ranks: Vec = participants + .iter() + .enumerate() + .filter_map(|(i, p)| { + if p.step != max_step || force_recover && primary.replica_id != p.replica_id { + Some(i) + } else { + None + } + }) + .collect(); let all_recover_dst_ranks_set = all_recover_dst_ranks.iter().collect::>(); let up_to_date_ranks: Vec = participants @@ -891,7 +889,7 @@ mod tests { #[tokio::test] async fn test_compute_quorum_results_skip_init_sync() -> Result<()> { - let quorum = Quorum { + let mut quorum = Quorum { quorum_id: 1, participants: vec![ QuorumMember { @@ -901,72 +899,36 @@ mod tests { step: 0, world_size: 1, shrink_only: false, + data: String::new(), }, QuorumMember { replica_id: "replica_1".to_string(), address: "addr_1".to_string(), store_address: "store_addr_1".to_string(), - step: 1, - world_size: 1, - shrink_only: false, - }, - QuorumMember { - replica_id: "replica_2".to_string(), - address: "addr_2".to_string(), - store_address: "store_addr_2".to_string(), - step: 0, - world_size: 1, - shrink_only: false, - }, - QuorumMember { - replica_id: "replica_3".to_string(), - address: "addr_3".to_string(), - store_address: "store_addr_3".to_string(), - step: 1, - world_size: 1, - shrink_only: false, - }, - QuorumMember { - replica_id: "replica_4".to_string(), - address: "addr_4".to_string(), - store_address: "store_addr_4".to_string(), step: 0, world_size: 1, shrink_only: false, + data: String::new(), }, ], created: None, }; - // rank 0 - - let results = compute_quorum_results("replica_0", 0, &quorum, false)?; + // baseline w/ init_sync=true + let results = compute_quorum_results("replica_0", 0, &quorum, true)?; assert!(!results.heal); - assert_eq!(results.recover_src_manager_address, "".to_string()); - assert_eq!(results.replica_rank, 0); - assert_eq!(results.recover_src_rank, None); - assert!(results.recover_dst_ranks.is_empty()); - let results = compute_quorum_results("replica_1", 0, &quorum, false)?; - assert!(!results.heal); - assert_eq!(results.recover_src_manager_address, "".to_string()); - assert_eq!(results.replica_rank, 1); - assert_eq!(results.recover_src_rank, None); - assert!(results.recover_dst_ranks.is_empty()); + let results = compute_quorum_results("replica_1", 0, &quorum, true)?; + assert!(results.heal); - let results = compute_quorum_results("replica_3", 0, &quorum, false)?; + // init_sync=false + let results = compute_quorum_results("replica_1", 0, &quorum, false)?; assert!(!results.heal); - assert_eq!(results.recover_src_manager_address, "".to_string()); - assert_eq!(results.replica_rank, 3); - assert_eq!(results.recover_src_rank, None); - assert!(results.recover_dst_ranks.is_empty()); - let results = compute_quorum_results("replica_1", 1, &quorum, false)?; - assert!(!results.heal); - assert_eq!(results.recover_src_manager_address, "".to_string()); - assert_eq!(results.replica_rank, 1); - assert_eq!(results.recover_src_rank, None); - assert!(results.recover_dst_ranks.is_empty()); + // init_sync=false, step=1 + quorum.participants[0].step = 1; + let results = compute_quorum_results("replica_1", 0, &quorum, false)?; + assert!(results.heal); Ok(()) } diff --git a/torchft/manager.py b/torchft/manager.py index 426fb35f..fa2760d8 100644 --- a/torchft/manager.py +++ b/torchft/manager.py @@ -34,7 +34,7 @@ from contextlib import nullcontext from datetime import timedelta from enum import Enum -from typing import Callable, cast, Dict, List, Optional, TYPE_CHECKING, TypeVar +from typing import TYPE_CHECKING, Callable, Dict, List, 
Optional, TypeVar, cast import torch from torch.distributed import ReduceOp, TCPStore @@ -106,6 +106,7 @@ def __init__( hostname: str = socket.gethostname(), heartbeat_interval: timedelta = timedelta(milliseconds=100), checkpoint_transport: Optional[CheckpointTransport[Dict[str, T]]] = None, + init_sync: bool = True, ) -> None: """ Args: @@ -143,6 +144,9 @@ def __init__( hostname: if rank==0, the hostname to advertise to the lighthouse server checkpoint_transport: the checkpoint transport to use for transfering checkpoints to recovering replicas, defaults to HTTPTransport + init_sync: whether to synchronize the model weights on step 0. If + all of the model weights are initialized identically via + ``torch.set_seed`` you should set this to False. """ self._load_state_dict = load_state_dict self._user_state_dict = state_dict @@ -152,6 +156,7 @@ def __init__( self._quorum_timeout = quorum_timeout self._connect_timeout = connect_timeout self._world_size_mode = world_size_mode + self._init_sync = init_sync store_addr = store_addr or os.environ["MASTER_ADDR"] store_port = store_port or int(os.environ["MASTER_PORT"]) @@ -455,7 +460,7 @@ def _async_quorum( checkpoint_metadata=self._checkpoint_transport.metadata(), shrink_only=shrink_only, timeout=quorum_timeout, - init_sync=self.init_sync, + init_sync=self._init_sync, ) quorum_id = quorum.quorum_id diff --git a/torchft/manager_integ_test.py b/torchft/manager_integ_test.py index d591d0d2..e7622be0 100644 --- a/torchft/manager_integ_test.py +++ b/torchft/manager_integ_test.py @@ -25,6 +25,8 @@ logger: logging.Logger = logging.getLogger(__name__) +INIT_LOCK: threading.Lock = threading.Lock() + class MyModel(nn.Module): def __init__(self, in_dim: int = 3, out_dim: int = 4) -> None: @@ -191,7 +193,13 @@ def state_dict() -> Dict[str, Dict[str, object]]: ) stack.callback(lambda: manager.shutdown(wait=False)) - m: nn.Module = DistributedDataParallel(manager, MyModel()) + with INIT_LOCK: + # We need to lock during init for testing init_sync=False as all + # threads share the same RNG + torch.manual_seed(42) + m: nn.Module = MyModel() + + m: nn.Module = DistributedDataParallel(manager, m) optimizer: optim.Optimizer = OptimizerWrapper( manager, optim.Adam(m.parameters()) ) @@ -270,7 +278,11 @@ def test_ddp_healthy(self) -> None: ), ] ) - def test_ddp_recovery(self, name: str, use_async_quorum: bool) -> None: + def test_ddp_recovery( + self, + name: str, + use_async_quorum: bool, + ) -> None: lighthouse = LighthouseServer( bind="[::]:0", min_replicas=2, @@ -302,7 +314,11 @@ def test_ddp_recovery(self, name: str, use_async_quorum: bool) -> None: state_dicts = [] for fut in as_completed(futures): - state_dicts.append(fut.result()) + try: + state_dicts.append(fut.result()) + except Exception as e: + print(e) + raise lighthouse.shutdown() @@ -311,6 +327,53 @@ def test_ddp_recovery(self, name: str, use_async_quorum: bool) -> None: self.assertEqual(failure_injectors[1].count, 1) + def test_ddp_skip_init_sync( + self, + ) -> None: + lighthouse = LighthouseServer( + bind="[::]:0", + min_replicas=2, + ) + num_replicas = 2 + futures = [] + + # no failures + failure_injectors = [ + FailureInjector(), + FailureInjector(), + ] + + with ThreadPoolExecutor(max_workers=num_replicas) as executor: + for replica_id, failure_injector in zip( + range(num_replicas), failure_injectors + ): + runner = Runner( + replica_id=replica_id, + num_replicas=num_replicas, + lighthouse_address=lighthouse.address(), + failure_injector=failure_injector, + manager_args={ + "use_async_quorum": 
False, + "init_sync": False, + }, + train_loop=ddp_train_loop, + ) + futures.append(executor.submit(runner.run_replica)) + + state_dicts = [] + + for fut in as_completed(futures): + try: + state_dicts.append(fut.result()) + except Exception as e: + print(e) + raise + + lighthouse.shutdown() + + for state_dict in state_dicts: + torch.testing.assert_close(state_dict, state_dicts[0]) + def test_ddp_recovery_multi_rank(self) -> None: lighthouse = LighthouseServer( bind="[::]:0", diff --git a/torchft/manager_test.py b/torchft/manager_test.py index 954b88aa..2d421e61 100644 --- a/torchft/manager_test.py +++ b/torchft/manager_test.py @@ -40,6 +40,7 @@ def _create_manager( min_replica_size: int = 2, world_size_mode: WorldSizeMode = WorldSizeMode.DYNAMIC, timeout: timedelta = timedelta(seconds=10), + init_sync: bool = True, ) -> Manager: pg = create_autospec(ProcessGroup) pg.errored.return_value = None @@ -67,6 +68,7 @@ def _create_manager( use_async_quorum=use_async_quorum, world_size_mode=world_size_mode, timeout=timeout, + init_sync=init_sync, ) self.manager = manager return manager @@ -617,7 +619,12 @@ def test_quorum_happy_timeouts(self, client_mock: MagicMock) -> None: @patch("torchft.manager.ManagerClient", autospec=True) def test_quorum_skip_init(self, client_mock: MagicMock) -> None: - manager = self._create_manager(use_async_quorum=False) + manager = self._create_manager( + use_async_quorum=False, + init_sync=False, + ) + + self.assertFalse(manager._init_sync) quorum = QuorumResult() quorum.quorum_id = 123 @@ -633,16 +640,8 @@ def test_quorum_skip_init(self, client_mock: MagicMock) -> None: client_mock()._quorum.return_value = quorum manager.start_quorum() - self.assertEqual( - client_mock()._quorum.call_args.kwargs["init_sync"], True - ) + self.assertEqual(client_mock()._quorum.call_args.kwargs["init_sync"], False) - manager.start_quorum(init_sync=True) - self.assertEqual( - client_mock()._quorum.call_args.kwargs["init_sync"], True - ) - - manager.start_quorum(init_sync=False) - self.assertEqual( - client_mock()._quorum.call_args.kwargs["init_sync"], False - ) + manager._init_sync = True + manager.start_quorum() + self.assertEqual(client_mock()._quorum.call_args.kwargs["init_sync"], True)
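
Usage sketch (not part of the patches above): a minimal example of how the init_sync
flag added in this series is meant to be used from Python. It assumes the usual
torchft environment (MASTER_ADDR/MASTER_PORT, a lighthouse, and the replica group
env vars) is already configured; the ProcessGroupGloo import path and the Manager
keyword arguments other than init_sync are based on the existing torchft API and
should be checked against the repo.

    import torch
    from torchft.manager import Manager
    from torchft.process_group import ProcessGroupGloo  # assumed import path

    # Seed the RNG identically on every replica so all ranks build the same
    # initial weights; this is what makes skipping the step-0 sync safe.
    torch.manual_seed(42)
    model = torch.nn.Linear(3, 4)

    manager = Manager(
        pg=ProcessGroupGloo(),
        load_state_dict=model.load_state_dict,
        state_dict=model.state_dict,
        min_replica_size=2,
        init_sync=False,  # skip the initial weight synchronization on step 0
    )

With init_sync=False and every replica still at step 0, compute_quorum_results no
longer marks any participant as a recovery destination, so no checkpoint transfer
happens on the first step; replicas that join at a later step are still healed as
before.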