Skip to content

Commit d0b506f

Browse files
ngoyal2707Naman Goyal
and
Naman Goyal
authored
added option for no PG validation for faster init (#1161)
Co-authored-by: Naman Goyal <[email protected]>
1 parent 3b7cc24 commit d0b506f

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

fairscale/nn/data_parallel/fully_sharded_data_parallel.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,7 @@ def __init__(
368368
gradient_predivide_factor: Optional[float] = None,
369369
limit_all_gather_events: bool = False,
370370
limit_reduce_scatter_events: bool = False,
371+
should_validate_process_group: bool = True,
371372
):
372373
try:
373374
import torch._C
@@ -451,7 +452,7 @@ def __init__(
451452
raise ValueError(f"offload type: '{offload_config.offload_type}' requires flatten_parameters=True")
452453

453454
# skip validation if the process group was created above
454-
if process_group:
455+
if process_group and should_validate_process_group:
455456
validate_process_group(self.compute_device, self.process_group)
456457

457458
# enable pytorch sync_bn just in case model contains sync_bn layers.

0 commit comments

Comments
 (0)