@@ -687,7 +687,7 @@ def _cast_buffers(
     @property
     def params_with_grad(self) -> List[Parameter]:
         """[p for p in self.parameters() if p.grad is not None]"""
-        return [p for p in self.parameters() if p.grad is not None]
+        return [p for p in self.parameters() if (p.grad is not None or p.main_grad is not None)]
 
     @torch.no_grad()
     def clip_grad_norm_(
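
With this change a parameter can carry its gradient in either `p.grad` or an FP32 `p.main_grad` buffer, so downstream consumers such as gradient clipping have to read whichever one is populated. A minimal sketch of that pattern; the helper names below are hypothetical and not part of this patch:

    import torch

    def _pick_grad(p: torch.nn.Parameter):
        # Prefer the regular .grad; fall back to the FP32 main_grad buffer.
        return p.grad if p.grad is not None else getattr(p, "main_grad", None)

    def local_grad_norm(params) -> torch.Tensor:
        grads = [g for g in map(_pick_grad, params) if g is not None]
        per_tensor = torch.stack([torch.linalg.vector_norm(g.detach().float()) for g in grads])
        return torch.linalg.vector_norm(per_tensor)
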
@@ -1714,30 +1714,48 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
         # Switch to FP32 shard after backward.
         self._use_fp32_param_shard([param])
 
+        if self.fp32_reduce_scatter:
+            if getattr(param, "unsharded_main_grad", None) is None:
+                param.unsharded_main_grad = param.grad.to(torch.float32)
+            else:
+                param.unsharded_main_grad.add_(param.grad.data)
+
+            param.grad = None
+
         if not self._require_backward_grad_sync:
             return
 
         # Wait for all work in the current stream to finish, then start the
         # reductions in post_backward stream.
         self._streams["post_backward"].wait_stream(torch.cuda.current_stream())
         with torch.cuda.stream(self._streams["post_backward"]):
-            orig_grad_data = param.grad.data
 
             if self.fp32_reduce_scatter:
                 # Cast grad to FP32.
-                param.grad.data = param.grad.data.float()
+                orig_grad_data = param.unsharded_main_grad.data
+            else:
+                orig_grad_data = param.grad.data
 
             if self.gradient_predivide_factor > 1:
                 # Average grad by world_size for consistency with PyTorch DDP.
-                param.grad.data.div_(self.gradient_predivide_factor)
+                if getattr(param, "unsharded_main_grad", None) is not None:
+                    param.unsharded_main_grad.data.div_(self.gradient_predivide_factor)
+                else:
+                    param.grad.data.div_(self.gradient_predivide_factor)
 
             if param._is_sharded:
                 assert self._reducer is not None
                 # Save the unsharded grad for reduction. We will asynchronously accumulate the reduced gradient into
                 # param._saved_grad_shard. If this FSDP module was called multiple times it's possible that multiple
                 # gradient reductions will happen in an undefined order. But addition commutes, so this order doesn't
                 # matter, neglecting rounding.
-                grad = param.grad.data
+                if getattr(param, "unsharded_main_grad", None) is not None:
+                    grad = param.unsharded_main_grad.data
+                    param.unsharded_main_grad = None
+                else:
+                    grad = param.grad.data
+                    param.grad = None
+
                 # Clear grad on the tensor, so any repeated gradient computations do not interfere with this reduction.
                 #
                 # The effect on memory consumption is not usually significant. No extra memory is allocated if this
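
The first added block in this hunk is an instance of a common mixed-precision pattern: fold each freshly computed low-precision gradient into a persistent FP32 buffer, then drop the low-precision `.grad` so repeated backward passes (e.g. gradient accumulation) keep adding into the FP32 copy. A standalone sketch of that pattern, outside FSDP and without the CUDA-stream handling; the attribute name mirrors the patch, but the snippet itself is illustrative:

    import torch

    def accumulate_main_grad(param: torch.nn.Parameter) -> None:
        # Same shape as the block above: create or add into an FP32 buffer,
        # then free the low-precision gradient.
        if getattr(param, "unsharded_main_grad", None) is None:
            param.unsharded_main_grad = param.grad.to(torch.float32)
        else:
            param.unsharded_main_grad.add_(param.grad)
        param.grad = None

    p = torch.nn.Parameter(torch.zeros(4, dtype=torch.bfloat16))
    p.grad = torch.full((4,), 0.5, dtype=torch.bfloat16)   # stand-in for backward #1
    accumulate_main_grad(p)
    p.grad = torch.full((4,), 0.5, dtype=torch.bfloat16)   # stand-in for backward #2
    accumulate_main_grad(p)
    assert p.grad is None and p.unsharded_main_grad.dtype == torch.float32
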
@@ -1749,7 +1767,6 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
                 # This ensures the `default` stream will wait for the `post_backward` stream to complete the last
                 # reduction for this module, before scheduling additional reduction work. Then at most there are two
                 # unsharded gradients allocated; one for a pending reduction, and one for gradient computation.
-                param.grad = None
                 callback_fn = functools.partial(self._post_reduction_hook, param)
                 self._reducer.reduce_scatter_async(
                     grad, group=self.process_group_reduce_scatter, callback_fn=callback_fn
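
For reference, the reduce-scatter scheduled through `reduce_scatter_async` follows the usual shape contract: every rank contributes the full (flattened, padded) gradient and receives the summed shard it owns. A rough synchronous sketch using plain `torch.distributed` collectives, with no bucketing or stream handling; it assumes an initialized process group and that `grad` divides evenly across ranks:

    import torch
    import torch.distributed as dist

    def reduce_scatter_grad(grad: torch.Tensor, group=None) -> torch.Tensor:
        world_size = dist.get_world_size(group)
        flat = grad.reshape(-1)
        shard = torch.empty(flat.numel() // world_size, dtype=flat.dtype, device=flat.device)
        # Sum-reduce across ranks, scattering one shard per rank.
        dist.reduce_scatter_tensor(shard, flat, op=dist.ReduceOp.SUM, group=group)
        return shard
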
@@ -1759,7 +1776,10 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
                 # world_size == 1. This could be relaxed in the future, in which
                 # case grads should be all-reduced here.
                 assert self.world_size == 1
-                self._post_reduction_hook(param, param.grad)
+                if getattr(param, "unsharded_main_grad", None) is not None:
+                    self._post_reduction_hook(param, param.unsharded_main_grad)
+                else:
+                    self._post_reduction_hook(param, param.grad)
 
             # After _post_backward_hook returns, orig_grad_data will eventually
             # go out of scope, at which point it could otherwise be freed for
@@ -1785,7 +1805,7 @@ def _post_reduction_hook(self, param: Parameter, reduced_grad: torch.Tensor) ->
         # non-blocking. The downside is a bit more D2H transfer in that case.
         if self.fp32_reduce_scatter:
             orig_param_grad_data = reduced_grad.data
-            reduced_grad.data = reduced_grad.data.to(dtype=param.data.dtype)
+            # reduced_grad.data = reduced_grad.data.to(dtype=param.data.dtype)
             # Don't let this memory get reused until after the transfer.
             orig_param_grad_data.record_stream(torch.cuda.current_stream())
 
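
Leaving the reduced gradient in FP32 here, rather than casting it back to the parameter dtype as the now-commented-out line did, avoids the usual low-precision accumulation problem. A small illustration of that effect, independent of FSDP:

    import torch

    # Repeatedly add a small bf16 increment. The bf16 running sum stalls once
    # the increment drops below half a ulp of the accumulator; the FP32 copy of
    # the same updates keeps growing.
    acc_bf16 = torch.zeros((), dtype=torch.bfloat16)
    acc_fp32 = torch.zeros((), dtype=torch.float32)
    step = torch.tensor(1e-3, dtype=torch.bfloat16)
    for _ in range(2048):
        acc_bf16 += step
        acc_fp32 += step.float()
    print(acc_bf16.item(), acc_fp32.item())  # bf16 total ends up far below ~2.05
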
@@ -1799,6 +1819,8 @@ def _post_reduction_hook(self, param: Parameter, reduced_grad: torch.Tensor) ->
             ), f"{param._saved_grad_shard.shape} vs {reduced_grad.shape}"
             param._saved_grad_shard.data += reduced_grad.data
             reduced_grad = param._saved_grad_shard.data
+        elif (param.grad is None) and self.fp32_reduce_scatter:
+            param.main_grad = reduced_grad.data
 
         # Optionally move gradients to CPU, typically used if one is running the optimizer on the CPU. Once the full
         # backwards pass completes, we will set `.grad` to the CPU copy.
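
With `fp32_reduce_scatter`, the reduced shard therefore ends up in `param.main_grad` in FP32 while `param.grad` stays unset. One way an optimizer step could consume those buffers is the usual master-weights pattern; the sketch below is illustrative only (not fairscale API), assuming `master_params` are FP32 clones of `model_params` registered with `optimizer`:

    import torch

    def step_with_main_grads(model_params, master_params, optimizer):
        # Feed the FP32 main_grad buffers to the FP32 master copies, step, then
        # copy the updated master weights back into the low-precision params.
        for p, mp in zip(model_params, master_params):
            grad = getattr(p, "main_grad", None)
            mp.grad = grad if grad is not None else (p.grad.float() if p.grad is not None else None)
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        for p, mp in zip(model_params, master_params):
            p.data.copy_(mp.data)   # implicit downcast back to the model dtype
            p.main_grad = None
            p.grad = None
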
@@ -1887,7 +1909,7 @@ def _finalize_parameters(fsdp_module: FullyShardedDataParallel) -> None:
                 if p.shape != p._saved_grad_shard.shape:
                     self._use_fp32_param_shard([p])
                     if p._saved_grad_shard.dtype != p.dtype:
-                        p.grad = p._saved_grad_shard.to(p.dtype)
+                        p.main_grad = p._saved_grad_shard
                     else:
                         p.grad = p._saved_grad_shard
 
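
Taken together, after `backward()` with `fp32_reduce_scatter` enabled, the sharded gradient is expected to live in `p.main_grad` in FP32 rather than being downcast into `p.grad`. A quick inspection sketch, assuming `model` is wrapped in FSDP built with this patch and `batch` already exists:

    loss = model(batch).sum()
    loss.backward()
    for name, p in model.named_parameters():
        main_grad = getattr(p, "main_grad", None)
        if main_grad is not None:
            # The reduced shard is kept in FP32 instead of being cast back down.
            print(name, main_grad.dtype, main_grad.shape, p.grad is None)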