
Commit 0f2229e

move changes after orig_grad_data
1 parent 60fa4f0

File tree

1 file changed: +16 additions, -22 deletions


fairscale/nn/data_parallel/fully_sharded_data_parallel.py

Lines changed: 16 additions & 22 deletions
@@ -1714,39 +1714,33 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
         # Switch to FP32 shard after backward.
         self._use_fp32_param_shard([param])

-        if self.fp32_reduce_scatter:
-            if param.grad is not None:
-                if param.main_grad is not None:
-                    param.main_grad.add_(param.grad.data.float())
-                else:
-                    param.main_grad = param.grad.data.float()
-                param.grad = None
-
         if not self._require_backward_grad_sync:
             return

         # Wait for all work in the current stream to finish, then start the
         # reductions in post_backward stream.
         self._streams["post_backward"].wait_stream(torch.cuda.current_stream())
         with torch.cuda.stream(self._streams["post_backward"]):
-            # orig_grad_data = param.main_grad.data
+            if param.main_grad is not None:
+                orig_grad_data = param.main_grad
+            else:
+                orig_grad_data = param.grad

             if self.fp32_reduce_scatter:
-                # Cast grad to FP32. with .main_grad params are already in FP32.
-                if param.main_grad is not None:
-                    orig_grad_data = param.main_grad.data
-                else:
-                    orig_grad_data = param.grad.data.to(torch.float32)
-            else:
-                orig_grad_data = param.grad.data
+                if param.grad is not None:
+                    if param.main_grad is not None:
+                        param.main_grad.copy_(param.grad.float())
+                    else:
+                        param.main_grad = param.grad.float()
+                    param.grad = None

             if self.gradient_predivide_factor > 1:
                 # Average grad by world_size for consistency with PyTorch DDP.
                 # param.grad.data.div_(self.gradient_predivide_factor)
                 if param.main_grad is not None:
-                    param.main_grad.data.div_(self.gradient_predivide_factor)
+                    param.main_grad.div_(self.gradient_predivide_factor)
                 else:
-                    param.grad.data.div_(self.gradient_predivide_factor)
+                    param.grad.div_(self.gradient_predivide_factor)

             if param._is_sharded:
                 assert self._reducer is not None
@@ -1755,10 +1749,10 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
                 # gradient reductions will happen in an undefined order. But addition commutes, so this order doesn't
                 # matter, neglecting rounding.
                 if param.main_grad is not None:
-                    grad = param.main_grad.data
+                    grad = param.main_grad
                     param.main_grad = None
                 else:
-                    grad = param.grad.data
+                    grad = param.grad
                 # Clear grad on the tensor, so any repeated gradient computations do not interfere with this reduction.
                 #
                 # The effect on memory consumption is not usually significant. No extra memory is allocated if this
@@ -1781,9 +1775,9 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
                 # case grads should be all-reduced here.
                 assert self.world_size == 1
                 if param.main_grad is not None:
-                    self._post_reduction_hook(param, param.main_grad.data)
+                    self._post_reduction_hook(param, param.main_grad)
                 else:
-                    self._post_reduction_hook(param, param.grad.data)
+                    self._post_reduction_hook(param, param.grad)

             # After _post_backward_hook returns, orig_grad_data will eventually
             # go out of scope, at which point it could otherwise be freed for
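For context, the ordering this commit establishes inside the hook can be condensed into a small, self-contained sketch. This is not FSDP itself: the Param class, the post_backward function, and the fp32_reduce_scatter / gradient_predivide_factor arguments below are toy stand-ins for the module state used in the diff. The sketch only mirrors the sequence the hunks above produce: capture a reference to whichever gradient tensor already exists, then copy the low-precision grad into the FP32 main_grad buffer, then apply the pre-divide, before any reduction would run.

from typing import Optional

import torch


class Param:
    """Toy stand-in for an FSDP-managed parameter's gradient state."""

    def __init__(self, grad: torch.Tensor, main_grad: Optional[torch.Tensor] = None):
        self.grad = grad            # low-precision (e.g. FP16) gradient from backward
        self.main_grad = main_grad  # optional preallocated FP32 master-gradient buffer


def post_backward(param: Param, fp32_reduce_scatter: bool, gradient_predivide_factor: float) -> torch.Tensor:
    # 1) Grab a reference to the gradient tensor that exists at hook entry, so its
    #    memory cannot be reused while asynchronous reduction work is still in flight.
    orig_grad_data = param.main_grad if param.main_grad is not None else param.grad

    # 2) With FP32 reduce-scatter enabled, move the low-precision grad into the FP32
    #    main_grad buffer and drop param.grad; later steps then operate on main_grad.
    if fp32_reduce_scatter and param.grad is not None:
        if param.main_grad is not None:
            param.main_grad.copy_(param.grad.float())
        else:
            param.main_grad = param.grad.float()
        param.grad = None

    # 3) Pre-divide, matching the gradient_predivide_factor branch in the diff.
    if gradient_predivide_factor > 1:
        if param.main_grad is not None:
            param.main_grad.div_(gradient_predivide_factor)
        else:
            param.grad.div_(gradient_predivide_factor)

    # The real hook would now reduce-scatter (or all-reduce) this tensor in the
    # post_backward stream; here we just hand it back.
    _ = orig_grad_data  # kept alive on purpose until the reduction would be done
    return param.main_grad if param.main_grad is not None else param.grad


if __name__ == "__main__":
    p = Param(grad=torch.randn(4).half())
    ready = post_backward(p, fp32_reduce_scatter=True, gradient_predivide_factor=2.0)
    print(ready.dtype)  # torch.float32

Capturing orig_grad_data before the FP32 copy matters because, as the trailing comment in the hook notes, that reference is what keeps the gradient tensor from going out of scope and being freed while the asynchronous work in the post_backward stream still needs it.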
