Skip to content

Commit 1f15ca9

Browse files
committed
git commit -m "Fix NCCL timeout issues in async checkpoint saving
- Add configure_nccl_for_checkpointing() to set appropriate timeouts - Implement automatic fallback from async to sync on NCCL timeouts - Configure NCCL_SOCKET_TIMEOUT and error handling for large models Addresses NCCL timeout failures reported in #14576 when using ckpt_async_save=True with large models like Qwen3 235B" Signed-off-by: Abenezer <[email protected]>
1 parent a035e05 commit 1f15ca9

File tree

1 file changed

+42
-7
lines changed

1 file changed

+42
-7
lines changed

nemo/utils/callbacks/dist_ckpt_io.py

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,21 @@
5858
f" Exact error: {e}"
5959
)
6060

61+
def configure_nccl_for_checkpointing():
62+
"""Configure NCCL settings for large model checkpoint saving."""
63+
import os
64+
65+
# Increase NCCL timeout for large model checkpointing
66+
if not os.environ.get('NCCL_ASYNC_ERROR_HANDLING'):
67+
os.environ['NCCL_ASYNC_ERROR_HANDLING'] = '1'
68+
69+
# Set longer timeout (1 hour) for checkpoint operations
70+
if not os.environ.get('NCCL_BLOCKING_WAIT'):
71+
os.environ['NCCL_BLOCKING_WAIT'] = '1'
72+
73+
# Increase socket timeout
74+
if not os.environ.get('NCCL_SOCKET_TIMEOUT'):
75+
os.environ['NCCL_SOCKET_TIMEOUT'] = '3600000'
6176

6277
@contextmanager
6378
def _debug_time(name: str):
@@ -289,14 +304,34 @@ def save_checkpoint(
289304

290305
rank = torch.distributed.get_rank()
291306
iteration = _get_iteration_from_checkpoint(checkpoint)
307+
308+
# Configure NCCL for async saves to prevent timeouts
309+
if self.async_save:
310+
configure_nccl_for_checkpointing()
311+
292312
start_time = time()
293-
async_save_request = dist_checkpointing.save(
294-
sharded_state_dict=checkpoint,
295-
checkpoint_dir=path,
296-
sharded_strategy=self.save_sharded_strategy,
297-
validate_access_integrity=validate_sharding_integrity,
298-
async_sharded_save=self.async_save,
299-
)
313+
314+
try:
315+
async_save_request = dist_checkpointing.save(
316+
sharded_state_dict=checkpoint,
317+
checkpoint_dir=path,
318+
sharded_strategy=self.save_sharded_strategy,
319+
validate_access_integrity=validate_sharding_integrity,
320+
async_sharded_save=self.async_save,
321+
)
322+
except Exception as e:
323+
if self.async_save and ("timeout" in str(e).lower() or "nccl" in str(e).lower()):
324+
logging.warning(f"Async save failed with NCCL timeout ({e}), falling back to sync save")
325+
async_save_request = dist_checkpointing.save(
326+
sharded_state_dict=checkpoint,
327+
checkpoint_dir=path,
328+
sharded_strategy=self.save_sharded_strategy,
329+
validate_access_integrity=validate_sharding_integrity,
330+
async_sharded_save=False,
331+
)
332+
else:
333+
raise
334+
300335
end_time = time()
301336
log_parts = (
302337
"Global Checkpoint Save",

0 commit comments

Comments
 (0)