
Commit f2bb56f

Avoid calling _free_fp16_param_shard() too early with PR 1159
1 parent a4f02ef commit f2bb56f

File tree

1 file changed (+3, -3 lines)


fairscale/nn/data_parallel/fully_sharded_data_parallel.py

Lines changed: 3 additions & 3 deletions
@@ -1733,7 +1733,7 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
             # Free full params.
             self._free_full_params([param])

-        if self.mixed_precision:
+        if self.mixed_precision and (self._require_backward_grad_sync or self.reshard_after_forward):
             # This is a no-op if reshard_after_forward is True, since we already
             # free the param shard when rebuilding the full params in the
             # pre_backward_hook.
@@ -1861,7 +1861,7 @@ def _post_reduction_hook(self, param: Parameter, reduced_grad: torch.Tensor) ->
     def _post_backward_reshard_hook(self, param: Parameter, *unused: Any) -> None:
         if self._should_free_in_backward():
             self._free_full_params([param])
-        if self.mixed_precision:
+        if self.mixed_precision and (self._require_backward_grad_sync or self.reshard_after_forward):
             self._free_fp16_param_shard([param])
         self._use_fp32_param_shard([param])

@@ -1937,7 +1937,7 @@ def _finalize_parameters(fsdp_module: FullyShardedDataParallel) -> None:
                 # For the 1st layer, if the forward inputs did not require
                 # gradient, then we cannot run a reshard hook for it, and
                 # we instead free here.
-                if p._full_param_padded.untyped_storage().size() > 0:
+                if p._is_sharded and p._full_param_padded.untyped_storage().size() > 0:
                     fsdp_module._post_backward_reshard_hook(p)
                     continue

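The condition added in the first two hunks frees the fp16 param shard in the post-backward path only when this backward pass will synchronize gradients (self._require_backward_grad_sync) or when reshard_after_forward is set; inside a no_sync() accumulation pass with reshard_after_forward=False the shard is kept for reuse by the next forward, which appears to be the "too early" free the commit title refers to. Below is a minimal sketch of that scenario, not part of the commit: the model, data_loader, and training loop are hypothetical, while mixed_precision, reshard_after_forward, and no_sync() are existing fairscale FSDP APIs; an initialized torch.distributed process group and a CUDA device are assumed.

# Hypothetical gradient-accumulation loop illustrating when the new guard
# keeps the fp16 param shard instead of freeing it after backward.
import torch
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

model = FSDP(
    torch.nn.Linear(1024, 1024).cuda(),
    mixed_precision=True,         # fp16 param shards exist alongside the fp32 shards
    reshard_after_forward=False,  # fp16 shard is not already freed after forward
)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

accumulation_steps = 4
for step, batch in enumerate(data_loader):  # data_loader is hypothetical
    if (step + 1) % accumulation_steps != 0:
        # Inside no_sync(), _require_backward_grad_sync is False; with
        # reshard_after_forward also False, the guarded hooks now skip
        # _free_fp16_param_shard() so the next forward can reuse the shard.
        with model.no_sync():
            model(batch).sum().backward()
    else:
        # Synchronizing pass: the guard is True again and the fp16 shard is
        # freed after backward, as before this commit.
        model(batch).sum().backward()
        optimizer.step()
        optimizer.zero_grad()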