Commit 3b7cc24

ngoyal2707 (Naman Goyal), vedanuj, and jiecaoyu authored
[not to be merged yet] added temp changes for fp32 main grad, might not work for TE (#1151)
* added temp changes for fp32 main grad, might not work for TE
* post rebase
* changes to keep reduced grad in fp32 (#1152)
* fix .grad=None issue when param is not sharded (#1153)
* fixed broken clipping (#1154)

Co-authored-by: Naman Goyal <[email protected]>

---------

Co-authored-by: Naman Goyal <[email protected]>
Co-authored-by: Vedanuj Goswami <[email protected]>
Co-authored-by: Jiecao Yu <[email protected]>
1 parent a8189f0 commit 3b7cc24

File tree

1 file changed: +31 -9 lines changed

fairscale/nn/data_parallel/fully_sharded_data_parallel.py

Lines changed: 31 additions & 9 deletions
@@ -687,7 +687,7 @@ def _cast_buffers(
     @property
     def params_with_grad(self) -> List[Parameter]:
         """[p for p in self.parameters() if p.grad is not None]"""
-        return [p for p in self.parameters() if p.grad is not None]
+        return [p for p in self.parameters() if (p.grad is not None or p.main_grad is not None)]
 
     @torch.no_grad()
     def clip_grad_norm_(
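The change to `params_with_grad` above makes the property also count parameters whose gradient lives only in a `main_grad` buffer after `.grad` has been cleared, which is what keeps gradient clipping from silently skipping those parameters (see "fixed broken clipping" in the commit message). Below is a minimal sketch of that `.grad`-or-`main_grad` lookup outside of FSDP; `grads_for_clipping` and `clip_by_global_norm_` are hypothetical helpers, and the sketch ignores sharding, so it illustrates only the fallback pattern, not fairscale's actual `clip_grad_norm_`.

import torch

def grads_for_clipping(module: torch.nn.Module) -> list:
    # Collect the tensor that actually holds each parameter's gradient,
    # preferring the fp32 `main_grad` buffer when `.grad` has been cleared.
    grads = []
    for p in module.parameters():
        if p.grad is not None:
            grads.append(p.grad)
        elif getattr(p, "main_grad", None) is not None:
            grads.append(p.main_grad)
    return grads

def clip_by_global_norm_(module: torch.nn.Module, max_norm: float) -> torch.Tensor:
    # Hypothetical helper mirroring torch.nn.utils.clip_grad_norm_: compute a
    # global L2 norm over the collected grads and scale them in place.
    grads = grads_for_clipping(module)
    if not grads:
        return torch.tensor(0.0)
    total_norm = torch.norm(torch.stack([torch.norm(g.detach(), 2) for g in grads]), 2)
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1:
        for g in grads:
            g.detach().mul_(clip_coef)
    return total_norm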
@@ -1714,30 +1714,48 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
         # Switch to FP32 shard after backward.
         self._use_fp32_param_shard([param])
 
+        if self.fp32_reduce_scatter:
+            if getattr(param, "unsharded_main_grad", None) is None:
+                param.unsharded_main_grad = param.grad.to(torch.float32)
+            else:
+                param.unsharded_main_grad.add_(param.grad.data)
+
+            param.grad = None
+
         if not self._require_backward_grad_sync:
             return
 
         # Wait for all work in the current stream to finish, then start the
         # reductions in post_backward stream.
         self._streams["post_backward"].wait_stream(torch.cuda.current_stream())
         with torch.cuda.stream(self._streams["post_backward"]):
-            orig_grad_data = param.grad.data
 
             if self.fp32_reduce_scatter:
                 # Cast grad to FP32.
-                param.grad.data = param.grad.data.float()
+                orig_grad_data = param.unsharded_main_grad.data
+            else:
+                orig_grad_data = param.grad.data
 
             if self.gradient_predivide_factor > 1:
                 # Average grad by world_size for consistency with PyTorch DDP.
-                param.grad.data.div_(self.gradient_predivide_factor)
+                if getattr(param, "unsharded_main_grad", None) is not None:
+                    param.unsharded_main_grad.data.div_(self.gradient_predivide_factor)
+                else:
+                    param.grad.data.div_(self.gradient_predivide_factor)
 
             if param._is_sharded:
                 assert self._reducer is not None
                 # Save the unsharded grad for reduction. We will asynchronously accumulate the reduced gradient into
                 # param._saved_grad_shard. If this FSDP module was called multiple times it's possible that multiple
                 # gradient reductions will happen in an undefined order. But addition commutes, so this order doesn't
                 # matter, neglecting rounding.
-                grad = param.grad.data
+                if getattr(param, "unsharded_main_grad", None) is not None:
+                    grad = param.unsharded_main_grad.data
+                    param.unsharded_main_grad = None
+                else:
+                    grad = param.grad.data
+                    param.grad = None
+
                 # Clear grad on the tensor, so any repeated gradient computations do not interfere with this reduction.
                 #
                 # The effect on memory consumption is not usually significant. No extra memory is allocated if this
 
@@ -1749,7 +1767,6 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
                 # This ensures the `default` stream will wait for the `post_backward` stream to complete the last
                 # reduction for this module, before scheduling additional reduction work. Then at most there are two
                 # unsharded gradients allocated; one for a pending reduction, and one for gradient computation.
-                param.grad = None
                 callback_fn = functools.partial(self._post_reduction_hook, param)
                 self._reducer.reduce_scatter_async(
                     grad, group=self.process_group_reduce_scatter, callback_fn=callback_fn
 
@@ -1759,7 +1776,10 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
                 # world_size == 1. This could be relaxed in the future, in which
                 # case grads should be all-reduced here.
                 assert self.world_size == 1
-                self._post_reduction_hook(param, param.grad)
+                if getattr(param, "unsharded_main_grad", None) is not None:
+                    self._post_reduction_hook(param, param.unsharded_main_grad)
+                else:
+                    self._post_reduction_hook(param, param.grad)
 
         # After _post_backward_hook returns, orig_grad_data will eventually
         # go out of scope, at which point it could otherwise be freed for
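Taken together, the `_post_backward_hook` changes above fold each freshly computed low-precision `.grad` into an fp32 `unsharded_main_grad` buffer, clear `.grad`, and then run pre-division and the reduce-scatter on the fp32 buffer, so micro-batch accumulation and the reduction both happen in full precision. The following is a self-contained sketch of the same accumulate-then-clear idea using a plain per-parameter hook rather than FSDP's internal hook machinery; `attach_fp32_accumulation_hooks` is a hypothetical helper, and `register_post_accumulate_grad_hook` requires PyTorch 2.1 or newer.

import torch

def attach_fp32_accumulation_hooks(module: torch.nn.Module) -> None:
    # After every backward, fold the low-precision `.grad` into an fp32
    # `unsharded_main_grad` buffer and drop `.grad`, so gradients accumulate
    # across micro-batches in fp32 rather than in the compute dtype.
    def make_hook(p: torch.nn.Parameter):
        def hook(*unused) -> None:
            if p.grad is None:
                return
            if getattr(p, "unsharded_main_grad", None) is None:
                p.unsharded_main_grad = p.grad.to(torch.float32)
            else:
                p.unsharded_main_grad.add_(p.grad.data)
            p.grad = None
        return hook

    for p in module.parameters():
        if p.requires_grad:
            # Fires once the gradient for `p` has been accumulated into `.grad`.
            p.register_post_accumulate_grad_hook(make_hook(p))

Accumulating in fp32 keeps small gradient contributions from being lost to fp16/bf16 rounding before the reduction, which is the point of carrying a separate main grad.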
@@ -1785,7 +1805,7 @@ def _post_reduction_hook(self, param: Parameter, reduced_grad: torch.Tensor) ->
         # non-blocking. The downside is a bit more D2H transfer in that case.
         if self.fp32_reduce_scatter:
             orig_param_grad_data = reduced_grad.data
-            reduced_grad.data = reduced_grad.data.to(dtype=param.data.dtype)
+            # reduced_grad.data = reduced_grad.data.to(dtype=param.data.dtype)
             # Don't let this memory get reused until after the transfer.
             orig_param_grad_data.record_stream(torch.cuda.current_stream())

@@ -1799,6 +1819,8 @@ def _post_reduction_hook(self, param: Parameter, reduced_grad: torch.Tensor) ->
             ), f"{param._saved_grad_shard.shape} vs {reduced_grad.shape}"
             param._saved_grad_shard.data += reduced_grad.data
             reduced_grad = param._saved_grad_shard.data
+        elif (param.grad is None) and self.fp32_reduce_scatter:
+            param.main_grad = reduced_grad.data
 
         # Optionally move gradients to CPU, typically used if one is running the optimizer on the CPU. Once the full
         # backwards pass completes, we will set `.grad` to the CPU copy.
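With the cast in `_post_reduction_hook` commented out, the reduced gradient stays in fp32, and when `fp32_reduce_scatter` is on and the parameter's `.grad` slot is empty it is attached as `param.main_grad` instead. Downstream code that only reads `.grad` (a stock `torch.optim` optimizer, for instance) then needs a small shim; the sketch below is one hypothetical way to do that, assuming the optimizer was built over the fp32 parameter shards so dtypes and shapes line up. An optimizer that reads `main_grad` natively would skip this entirely.

import torch

def step_with_main_grad(optimizer: torch.optim.Optimizer) -> None:
    # Surface the fp32 `main_grad` as `.grad` for parameters whose `.grad`
    # slot is empty, run the step, then restore the main_grad-only convention.
    patched = []
    for group in optimizer.param_groups:
        for p in group["params"]:
            main_grad = getattr(p, "main_grad", None)
            if p.grad is None and main_grad is not None:
                p.grad = main_grad
                patched.append(p)
    optimizer.step()
    for p in patched:
        p.grad = None
        p.main_grad = None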
@@ -1887,7 +1909,7 @@ def _finalize_parameters(fsdp_module: FullyShardedDataParallel) -> None:
                     if p.shape != p._saved_grad_shard.shape:
                         self._use_fp32_param_shard([p])
                     if p._saved_grad_shard.dtype != p.dtype:
-                        p.grad = p._saved_grad_shard.to(p.dtype)
+                        p.main_grad = p._saved_grad_shard
                     else:
                         p.grad = p._saved_grad_shard
