@@ -1714,39 +1714,33 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
         # Switch to FP32 shard after backward.
         self._use_fp32_param_shard([param])

-        if self.fp32_reduce_scatter:
-            if param.grad is not None:
-                if param.main_grad is not None:
-                    param.main_grad.add_(param.grad.data.float())
-                else:
-                    param.main_grad = param.grad.data.float()
-                param.grad = None
-
         if not self._require_backward_grad_sync:
             return

         # Wait for all work in the current stream to finish, then start the
         # reductions in post_backward stream.
         self._streams["post_backward"].wait_stream(torch.cuda.current_stream())
         with torch.cuda.stream(self._streams["post_backward"]):
-            # orig_grad_data = param.main_grad.data
+            if param.main_grad is not None:
+                orig_grad_data = param.main_grad
+            else:
+                orig_grad_data = param.grad

             if self.fp32_reduce_scatter:
-                # Cast grad to FP32. with .main_grad params are already in FP32.
-                if param.main_grad is not None:
-                    orig_grad_data = param.main_grad.data
-                else:
-                    orig_grad_data = param.grad.data.to(torch.float32)
-            else:
-                orig_grad_data = param.grad.data
+                if param.grad is not None:
+                    if param.main_grad is not None:
+                        param.main_grad.copy_(param.grad.float())
+                    else:
+                        param.main_grad = param.grad.float()
+                    param.grad = None

             if self.gradient_predivide_factor > 1:
                 # Average grad by world_size for consistency with PyTorch DDP.
                 # param.grad.data.div_(self.gradient_predivide_factor)
                 if param.main_grad is not None:
-                    param.main_grad.data.div_(self.gradient_predivide_factor)
+                    param.main_grad.div_(self.gradient_predivide_factor)
                 else:
-                    param.grad.data.div_(self.gradient_predivide_factor)
+                    param.grad.div_(self.gradient_predivide_factor)

             if param._is_sharded:
                 assert self._reducer is not None
@@ -1755,10 +1749,10 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
                 # gradient reductions will happen in an undefined order. But addition commutes, so this order doesn't
                 # matter, neglecting rounding.
                 if param.main_grad is not None:
-                    grad = param.main_grad.data
+                    grad = param.main_grad
                     param.main_grad = None
                 else:
-                    grad = param.grad.data
+                    grad = param.grad
                 # Clear grad on the tensor, so any repeated gradient computations do not interfere with this reduction.
                 #
                 # The effect on memory consumption is not usually significant. No extra memory is allocated if this
@@ -1781,9 +1775,9 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
                 # case grads should be all-reduced here.
                 assert self.world_size == 1
                 if param.main_grad is not None:
-                    self._post_reduction_hook(param, param.main_grad)
+                    self._post_reduction_hook(param, param.main_grad)
                 else:
-                    self._post_reduction_hook(param, param.grad.data)
+                    self._post_reduction_hook(param, param.grad)

             # After _post_backward_hook returns, orig_grad_data will eventually
             # go out of scope, at which point it could otherwise be freed for
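
A minimal self-contained sketch of the stream-safety pattern these hook comments describe: wait on the current stream before launching the reduction on the post_backward stream, then record_stream the gradient so its memory is not reused while the async work is in flight. This is not code from the PR; the function name and the plain all_reduce are illustrative stand-ins for the reduce-scatter machinery above, and it assumes torch.distributed has been initialized with a CUDA-aware backend such as NCCL.

import torch
import torch.distributed as dist


def reduce_grad_on_side_stream(grad: torch.Tensor, side_stream: torch.cuda.Stream) -> None:
    # Make the side stream wait for all work already queued on the current
    # stream (e.g. the backward kernels that produced `grad`).
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        # The collective is enqueued on `side_stream` and runs asynchronously
        # with respect to the main stream.
        dist.all_reduce(grad)
        # Tell the caching allocator that `grad` is still in use by
        # `side_stream`, so its memory is not handed back to the main stream
        # before the collective finishes.
        grad.record_stream(side_stream)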