@@ -1713,6 +1713,13 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
 
         # Switch to FP32 shard after backward.
         self._use_fp32_param_shard([param])
+        if self.mixed_precision and self.fp32_reduce_scatter:
+            if getattr(param, "main_grad", None) is None:
+                param.main_grad = param.grad.to(torch.float32)
+            else:
+                param.main_grad.add_(param.grad.data)
+
+            param.grad = None
 
         if not self._require_backward_grad_sync:
             return
@@ -1721,23 +1728,31 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
         # reductions in post_backward stream.
         self._streams["post_backward"].wait_stream(torch.cuda.current_stream())
         with torch.cuda.stream(self._streams["post_backward"]):
-            orig_grad_data = param.grad.data
 
             if self.fp32_reduce_scatter:
                 # Cast grad to FP32.
                 param.grad.data = param.grad.data.float()
 
+            orig_grad_data = param.grad.data
+
             if self.gradient_predivide_factor > 1:
                 # Average grad by world_size for consistency with PyTorch DDP.
-                param.grad.data.div_(self.gradient_predivide_factor)
+                if getattr(param, "main_grad", None) is not None:
+                    param.main_grad.data.div_(self.gradient_predivide_factor)
+                else:
+                    param.grad.data.div_(self.gradient_predivide_factor)
 
             if param._is_sharded:
                 assert self._reducer is not None
                 # Save the unsharded grad for reduction. We will asynchronously accumulate the reduced gradient into
                 # param._saved_grad_shard. If this FSDP module was called multiple times it's possible that multiple
                 # gradient reductions will happen in an undefined order. But addition commutes, so this order doesn't
                 # matter, neglecting rounding.
-                grad = param.grad.data
+                if getattr(param, "main_grad", None) is not None:
+                    grad = param.main_grad.data
+                    param.main_grad = None
+                else:
+                    grad = param.grad.data
                 # Clear grad on the tensor, so any repeated gradient computations do not interfere with this reduction.
                 #
                 # The effect on memory consumption is not usually significant. No extra memory is allocated if this
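The hunks above accumulate low-precision gradients into an FP32 param.main_grad buffer in the post-backward hook, then let the pre-divide and reduce-scatter path read that buffer instead of param.grad. The following is a minimal standalone sketch of the same accumulation pattern outside FSDP. It assumes PyTorch 2.1+ for Tensor.register_post_accumulate_grad_hook, which stands in here for FSDP's internal post-backward hook; the tiny model and the attach_fp32_grad_accumulation helper are illustrative assumptions, not part of this change.

# Standalone sketch of the main_grad pattern above (not part of this diff).
# A public post-accumulate-grad hook stands in for FSDP's internal
# post-backward hook; the model and helper names are illustrative.
import torch
import torch.nn as nn


def attach_fp32_grad_accumulation(module: nn.Module) -> None:
    """Accumulate low-precision grads into an FP32 `main_grad` buffer."""

    def hook(p: torch.Tensor) -> None:
        if p.grad is None:
            return
        if getattr(p, "main_grad", None) is None:
            # First backward: upcast the freshly accumulated grad to FP32.
            p.main_grad = p.grad.to(torch.float32)
        else:
            # Later backwards: accumulate in FP32 to avoid bf16/fp16 rounding.
            p.main_grad.add_(p.grad.data)
        # Drop the low-precision grad; downstream code reads main_grad instead.
        p.grad = None

    for param in module.parameters():
        if param.requires_grad:
            # Requires PyTorch 2.1+.
            param.register_post_accumulate_grad_hook(hook)


device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(8, 8).to(device=device, dtype=torch.bfloat16)
attach_fp32_grad_accumulation(model)

for _ in range(4):  # simulate gradient-accumulation micro-batches
    x = torch.randn(2, 8, device=device, dtype=torch.bfloat16)
    model(x).float().pow(2).mean().backward()

print(model.weight.main_grad.dtype)  # torch.float32

# At reduction time (second hunk above), the FP32 buffer is what gets
# pre-divided and handed to the reduce-scatter; sketched without the
# actual collective:
predivide_factor = 2.0  # stand-in for self.gradient_predivide_factor
for p in model.parameters():
    grad, p.main_grad = p.main_grad, None
    grad.div_(predivide_factor)
    # ... reduce-scatter `grad` across ranks here ...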