@@ -1733,7 +1733,7 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
             # Free full params.
             self._free_full_params([param])

-        if self.mixed_precision:
+        if self.mixed_precision and (self._require_backward_grad_sync or self.reshard_after_forward):
             # This is a no-op if reshard_after_forward is True, since we already
             # free the param shard when rebuilding the full params in the
             # pre_backward_hook.
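The new `_require_backward_grad_sync` term in the guard is the flag that FSDP's `no_sync()` context clears during gradient-accumulation passes, so with `reshard_after_forward=False` the fp16 param shard is no longer freed on those passes. A minimal sketch of that usage pattern under mixed precision, assuming a single process with one CUDA device; the model, sizes, and optimizer are illustrative and not taken from this PR:

import os
import torch
import torch.distributed as dist
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# Single-process setup for illustration; fp16 compute assumes a CUDA device.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("nccl", rank=0, world_size=1)
torch.cuda.set_device(0)

model = FSDP(
    torch.nn.Linear(16, 16).cuda(),
    mixed_precision=True,
    reshard_after_forward=False,
)
optim = torch.optim.SGD(model.parameters(), lr=0.1)
batches = [torch.randn(4, 16).cuda() for _ in range(3)]

# Accumulation passes: _require_backward_grad_sync is False inside no_sync(),
# so with reshard_after_forward=False the guard above skips freeing the
# fp16 param shard in the post-backward hook.
with model.no_sync():
    for x in batches[:-1]:
        model(x).sum().backward()

# Final, synchronized pass: gradients are reduced and shards are freed as before.
model(batches[-1]).sum().backward()
optim.step()
optim.zero_grad()
dist.destroy_process_group()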
@@ -1861,7 +1861,7 @@ def _post_reduction_hook(self, param: Parameter, reduced_grad: torch.Tensor) ->
     def _post_backward_reshard_hook(self, param: Parameter, *unused: Any) -> None:
         if self._should_free_in_backward():
             self._free_full_params([param])
-        if self.mixed_precision:
+        if self.mixed_precision and (self._require_backward_grad_sync or self.reshard_after_forward):
             self._free_fp16_param_shard([param])
         self._use_fp32_param_shard([param])
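The same condition gates the fp16-shard free in both hooks. Restating it as a standalone predicate (a hypothetical helper for illustration, not part of the FairScale API) spells out which combinations still free the shard:

def frees_fp16_shard(mixed_precision: bool,
                     require_backward_grad_sync: bool,
                     reshard_after_forward: bool) -> bool:
    # Mirrors the guard added in this PR: the fp16 param shard is freed in the
    # post-backward path only under mixed precision, and only if this pass
    # synchronizes gradients or full params are resharded after forward anyway.
    return mixed_precision and (require_backward_grad_sync or reshard_after_forward)

# Accumulation-only pass (inside no_sync()) with reshard_after_forward=False:
# the shard is kept.
assert not frees_fp16_shard(True, False, False)
# Regular synchronized pass: the shard is freed, as before this change.
assert frees_fp16_shard(True, True, False)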
@@ -1937,7 +1937,7 @@ def _finalize_parameters(fsdp_module: FullyShardedDataParallel) -> None:
                     # For the 1st layer, if the forward inputs did not require
                     # gradient, then we cannot run a reshard hook for it, and
                     # we instead free here.
-                    if p._full_param_padded.untyped_storage().size() > 0:
+                    if p._is_sharded and p._full_param_padded.untyped_storage().size() > 0:
                         fsdp_module._post_backward_reshard_hook(p)
                     continue
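The added `p._is_sharded` check short-circuits before `_full_param_padded` is touched; FSDP sets that attribute up for the parameters it actually shards, so an unsharded parameter reaching this path may not have it at all. A small sketch of the guarded check with a stand-in parameter object; the class is illustrative, while the attribute names follow the diff:

import torch

class _StubParam:
    # Illustrative stand-in only: _full_param_padded is attached to sharded
    # parameters, and an unsharded parameter may never receive the attribute.
    def __init__(self, is_sharded: bool):
        self._is_sharded = is_sharded
        if is_sharded:
            self._full_param_padded = torch.empty(8)

def still_holds_full_param(p) -> bool:
    # Short-circuit on _is_sharded first, exactly as in the changed line, so
    # _full_param_padded is only dereferenced for sharded parameters.
    return p._is_sharded and p._full_param_padded.untyped_storage().size() > 0

assert still_holds_full_param(_StubParam(is_sharded=True))
assert not still_holds_full_param(_StubParam(is_sharded=False))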