From 5f01990859e4f4d178d580bc55504e44c3564a35 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Tue, 18 Feb 2025 23:58:39 -0800 Subject: [PATCH] Disable async quorum for the first quorum sync If we don't wait for the first quorum, the trainer will proceed to the forward pass and may use incorrect weights if it is healing. --- torchft/manager.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/torchft/manager.py b/torchft/manager.py index 668189c2..967365eb 100644 --- a/torchft/manager.py +++ b/torchft/manager.py @@ -151,6 +151,7 @@ def __init__( self._quorum_timeout = quorum_timeout self._connect_timeout = connect_timeout self._world_size_mode = world_size_mode + self._first_quorum = True store_addr = store_addr or os.environ["MASTER_ADDR"] store_port = store_port or int(os.environ["MASTER_PORT"]) @@ -404,8 +405,11 @@ def start_quorum( shrink_only=shrink_only, quorum_timeout=timeout or self._quorum_timeout, ) - if not self._use_async_quorum: + # If this is the first quorum sync, we need to wait for the result. + # Otherwise, we may mistakenly perform the forward with incorrect weights. + if not self._use_async_quorum or self._first_quorum: self.wait_quorum() + self._first_quorum = False if self._healing: # eagerly apply pending state_dict so we can run the forwards pass