|
58 | 58 | f" Exact error: {e}" |
59 | 59 | ) |
60 | 60 |
|
def configure_nccl_for_checkpointing():
    """Seed NCCL-related environment variables ahead of a large checkpoint save.

    Each variable below is written only when it is currently missing from the
    environment or set to an empty string; a non-empty value that is already
    present is left untouched. The function returns nothing and its only
    effect is to mutate ``os.environ`` for the current process.

    NOTE(review): ``NCCL_ASYNC_ERROR_HANDLING`` and ``NCCL_BLOCKING_WAIT`` look
    like on/off switches rather than timeout durations, and PyTorch's
    documentation appears to say only one of the two should be set -- confirm
    that enabling both here is intended.
    NOTE(review): ``NCCL_SOCKET_TIMEOUT`` could not be matched to a documented
    NCCL variable; ``3600000`` is presumably one hour in milliseconds -- verify.
    """
    import os

    # (variable name, value applied when the variable is unset or empty),
    # processed in this order.
    fallback_values = (
        ('NCCL_ASYNC_ERROR_HANDLING', '1'),
        ('NCCL_BLOCKING_WAIT', '1'),
        ('NCCL_SOCKET_TIMEOUT', '3600000'),
    )

    for variable, fallback in fallback_values:
        if os.environ.get(variable):
            # A non-empty value is already present; leave it as it is.
            continue
        os.environ[variable] = fallback
61 | 76 |
|
62 | 77 | @contextmanager |
63 | 78 | def _debug_time(name: str): |
@@ -289,14 +304,34 @@ def save_checkpoint( |
289 | 304 |
|
290 | 305 | rank = torch.distributed.get_rank() |
291 | 306 | iteration = _get_iteration_from_checkpoint(checkpoint) |
| 307 | + |
| 308 | + # Configure NCCL for async saves to prevent timeouts |
| 309 | + if self.async_save: |
| 310 | + configure_nccl_for_checkpointing() |
| 311 | + |
292 | 312 | start_time = time() |
293 | | - async_save_request = dist_checkpointing.save( |
294 | | - sharded_state_dict=checkpoint, |
295 | | - checkpoint_dir=path, |
296 | | - sharded_strategy=self.save_sharded_strategy, |
297 | | - validate_access_integrity=validate_sharding_integrity, |
298 | | - async_sharded_save=self.async_save, |
299 | | - ) |
| 313 | + |
| 314 | + try: |
| 315 | + async_save_request = dist_checkpointing.save( |
| 316 | + sharded_state_dict=checkpoint, |
| 317 | + checkpoint_dir=path, |
| 318 | + sharded_strategy=self.save_sharded_strategy, |
| 319 | + validate_access_integrity=validate_sharding_integrity, |
| 320 | + async_sharded_save=self.async_save, |
| 321 | + ) |
| 322 | + except Exception as e: |
| 323 | + if self.async_save and ("timeout" in str(e).lower() or "nccl" in str(e).lower()): |
| 324 | + logging.warning(f"Async save failed with NCCL timeout ({e}), falling back to sync save") |
| 325 | + async_save_request = dist_checkpointing.save( |
| 326 | + sharded_state_dict=checkpoint, |
| 327 | + checkpoint_dir=path, |
| 328 | + sharded_strategy=self.save_sharded_strategy, |
| 329 | + validate_access_integrity=validate_sharding_integrity, |
| 330 | + async_sharded_save=False, |
| 331 | + ) |
| 332 | + else: |
| 333 | + raise |
| 334 | + |
300 | 335 | end_time = time() |
301 | 336 | log_parts = ( |
302 | 337 | "Global Checkpoint Save", |
|
0 commit comments