1 parent 3b7cc24 commit d0b506f
fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -368,6 +368,7 @@ def __init__(
         gradient_predivide_factor: Optional[float] = None,
         limit_all_gather_events: bool = False,
         limit_reduce_scatter_events: bool = False,
+        should_validate_process_group: bool = True,
     ):
         try:
             import torch._C
@@ -451,7 +452,7 @@ def __init__(
             raise ValueError(f"offload type: '{offload_config.offload_type}' requires flatten_parameters=True")

         # skip validation if the process group was created above
-        if process_group:
+        if process_group and should_validate_process_group:
             validate_process_group(self.compute_device, self.process_group)

         # enable pytorch sync_bn just in case model contains sync_bn layers.
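The new flag lets a caller that passes an already-verified process group opt out of the validate_process_group() sanity check, which otherwise runs a collective at construction time. A minimal usage sketch (not part of the commit; it assumes torch.distributed is already initialized, e.g. via torchrun, and the module and group below are illustrative):

import torch
import torch.distributed as dist
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# Assumes dist.init_process_group(...) has already been called.
pg = dist.new_group()

# With should_validate_process_group=False, FSDP skips the
# validate_process_group() check that normally runs when an explicit
# process_group is supplied.
model = FSDP(
    torch.nn.Linear(128, 128),
    process_group=pg,
    should_validate_process_group=False,
)

The default of True preserves the existing behavior, so only callers that explicitly opt out skip the check.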