Commit d1bf367

Avoid calling _free_fp16_param_shard() too early
1 parent d0b506f commit d1bf367

1 file changed, with 1 addition and 1 deletion

fairscale/nn/data_parallel/fully_sharded_data_parallel.py

Lines changed: 1 addition & 1 deletion
@@ -1706,7 +1706,7 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
             # bandwidth but uses more GPU memory.
             self._free_full_params([param])
 
-        if self.mixed_precision:
+        if self.mixed_precision and (self._require_backward_grad_sync or self.reshard_after_forward):
             # This is a no-op if reshard_after_forward is True, since we already
             # free the param shard when rebuilding the full params in the
             # pre_backward_hook.
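
To see which configurations the tightened guard actually affects, the old and new conditions can be compared over every flag combination. This is a minimal standalone sketch, not fairscale code: `old_guard` and `should_free_fp16_shard` are hypothetical helper names, and their arguments simply stand in for `self.mixed_precision`, `self._require_backward_grad_sync`, and `self.reshard_after_forward` from the diff.

```python
from itertools import product


def old_guard(mixed_precision: bool) -> bool:
    """Guard before this commit: only mixed_precision was checked."""
    return mixed_precision


def should_free_fp16_shard(mixed_precision: bool,
                           require_backward_grad_sync: bool,
                           reshard_after_forward: bool) -> bool:
    """Hypothetical helper mirroring the new guard from the diff."""
    return mixed_precision and (require_backward_grad_sync or reshard_after_forward)


# Collect the (mixed_precision, require_backward_grad_sync, reshard_after_forward)
# combinations where the old and new guards disagree.
changed = [
    (mp, sync, reshard)
    for mp, sync, reshard in product([False, True], repeat=3)
    if old_guard(mp) != should_free_fp16_shard(mp, sync, reshard)
]
print(changed)  # [(True, False, False)]
```

The only combination that behaves differently is mixed precision with both `_require_backward_grad_sync` and `reshard_after_forward` set to False. In that case the hook now skips the code under this guard, which, going by the commit title, is where `_free_fp16_param_shard()` was being called too early.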
