
Commit 3f34441

Add main_grad
1 parent 45cd038

File tree

1 file changed: +18 -3 lines changed


fairscale/nn/data_parallel/fully_sharded_data_parallel.py

Lines changed: 18 additions & 3 deletions
@@ -1713,6 +1713,13 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
 
         # Switch to FP32 shard after backward.
         self._use_fp32_param_shard([param])
+        if self.mixed_precision and self.fp32_reduce_scatter:
+            if getattr(param, "main_grad", None) is None:
+                param.main_grad = param.grad.to(torch.float32)
+            else:
+                param.main_grad.add_(param.grad.data)
+
+            param.grad = None
 
         if not self._require_backward_grad_sync:
             return
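
In this first hunk, when mixed precision and FP32 reduce-scatter are both enabled, the post-backward hook folds the low-precision param.grad into an FP32 param.main_grad accumulator and then clears param.grad, so gradients from repeated backward passes are summed in full precision. Below is a minimal standalone sketch of that accumulation pattern, independent of FSDP; the accumulate_main_grad helper is a hypothetical name, not part of the fairscale API.

import torch


def accumulate_main_grad(param: torch.nn.Parameter) -> None:
    # Hypothetical helper mirroring the accumulation added above; not fairscale API.
    if param.grad is None:
        return
    if getattr(param, "main_grad", None) is None:
        # First backward pass: allocate the FP32 accumulator as a cast copy.
        param.main_grad = param.grad.to(torch.float32)
    else:
        # Later backward passes: accumulate in FP32 to limit rounding error.
        param.main_grad.add_(param.grad.data)
    # Drop the low-precision grad; autograd repopulates it on the next backward pass.
    param.grad = None


# Usage sketch: two backward passes whose gradients end up summed in FP32.
p = torch.nn.Parameter(torch.zeros(4, dtype=torch.float16))
for _ in range(2):
    p.float().sum().backward()
    accumulate_main_grad(p)
print(p.main_grad)  # tensor([2., 2., 2., 2.]), accumulated in torch.float32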
@@ -1721,23 +1728,31 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
         # reductions in post_backward stream.
         self._streams["post_backward"].wait_stream(torch.cuda.current_stream())
         with torch.cuda.stream(self._streams["post_backward"]):
-            orig_grad_data = param.grad.data
 
             if self.fp32_reduce_scatter:
                 # Cast grad to FP32.
                 param.grad.data = param.grad.data.float()
 
+            orig_grad_data = param.grad.data
+
             if self.gradient_predivide_factor > 1:
                 # Average grad by world_size for consistency with PyTorch DDP.
-                param.grad.data.div_(self.gradient_predivide_factor)
+                if getattr(param, "main_grad", None) is not None:
+                    param.main_grad.data.div_(self.gradient_predivide_factor)
+                else:
+                    param.grad.data.div_(self.gradient_predivide_factor)
 
             if param._is_sharded:
                 assert self._reducer is not None
                 # Save the unsharded grad for reduction. We will asynchronously accumulate the reduced gradient into
                 # param._saved_grad_shard. If this FSDP module was called multiple times it's possible that multiple
                 # gradient reductions will happen in an undefined order. But addition commutes, so this order doesn't
                 # matter, neglecting rounding.
-                grad = param.grad.data
+                if getattr(param, "main_grad", None) is not None:
+                    grad = param.main_grad.data
+                    param.main_grad = None
+                else:
+                    grad = param.grad.data
                 # Clear grad on the tensor, so any repeated gradient computations do not interfere with this reduction.
                 #
                 # The effect on memory consumption is not usually significant. No extra memory is allocated if this
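
The second hunk then makes the reduction path prefer main_grad whenever the hook above created it: the pre-divide and the tensor handed to self._reducer come from the FP32 accumulator, which is released once consumed, with param.grad as the fallback. Below is a rough sketch of that selection plus pre-divide, with FSDP's internal reducer replaced by a plain torch.distributed.reduce_scatter call; it assumes an already initialized process group and a gradient evenly divisible across ranks, and reduce_full_grad is a hypothetical name.

import torch
import torch.distributed as dist


def reduce_full_grad(param: torch.nn.Parameter, predivide_factor: float = 1.0) -> torch.Tensor:
    # Prefer the FP32 accumulator when the post-backward hook populated it.
    if getattr(param, "main_grad", None) is not None:
        grad = param.main_grad.data
        param.main_grad = None  # release the accumulator once it is handed off
    else:
        grad = param.grad.data

    if predivide_factor > 1:
        # Pre-divide for consistency with PyTorch DDP's gradient averaging.
        grad.div_(predivide_factor)

    # Each rank keeps one equal-sized shard of the summed gradient.
    world_size = dist.get_world_size()
    shard = torch.empty_like(grad.chunk(world_size)[0])
    dist.reduce_scatter(shard, list(grad.chunk(world_size)), op=dist.ReduceOp.SUM)
    return shard

Releasing main_grad as soon as it is handed off mirrors the diff: the FP32 copy only needs to live from the backward hook until the reduce-scatter consumes it.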
