Commit 74a7313

Support for only performing norm weights allreduce in last microbatch
1 parent d0b506f commit 74a7313

2 files changed: +20 additions, −0 deletions

fairscale/nn/data_parallel/fully_sharded_data_parallel.py

Lines changed: 13 additions & 0 deletions
@@ -369,6 +369,7 @@ def __init__(
         limit_all_gather_events: bool = False,
         limit_reduce_scatter_events: bool = False,
         should_validate_process_group: bool = True,
+        tensor_parallel_group: Optional[ProcessGroup] = None,
     ):
         try:
             import torch._C
@@ -380,6 +381,7 @@ def __init__(
         init_start = time.time()
         super().__init__()
         self.process_group = process_group or get_process_group_cached()
+        self.tensor_parallel_group = tensor_parallel_group
         # If ProcessGroupName.default is passed in, the reduce_scatter will use the same process group with
         # the rest of operations. The overlap feature in the backward propagation is disabled.
         if process_group_reduce_scatter == ProcessGroupName.default:
@@ -1726,6 +1728,17 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
         if not self._require_backward_grad_sync:
             return
 
+        # All-reduce the flagged norm-weight gradient slices across the tensor-parallel group.
+        if self.tensor_parallel_group is not None and self.tensor_parallel_group.size() > 1:
+            if self.fp32_reduce_scatter:
+                orig_grad_data = param.unsharded_main_grad.data
+            else:
+                orig_grad_data = param.grad.data
+            for idx_pair in param._param_require_tp_allreduce:
+                param_allreduce = orig_grad_data[idx_pair[0]:idx_pair[1]].contiguous()
+                torch.distributed.all_reduce(param_allreduce, group=self.tensor_parallel_group)
+                orig_grad_data[idx_pair[0]:idx_pair[1]].copy_(param_allreduce)
+
         # Wait for all work in the current stream to finish, then start the
         # reductions in post_backward stream.
         self._streams["post_backward"].wait_stream(torch.cuda.current_stream())
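Because _post_backward_hook returns early when _require_backward_grad_sync is False, the tensor-parallel all-reduce above is skipped for microbatches run under no_sync() and only fires on the last microbatch. Below is a minimal usage sketch, not part of the commit: it assumes torch.distributed is already initialized and that the caller has built the tensor-parallel process group; the helper names, the toy loss, and the restriction to nn.LayerNorm are illustrative, while the tensor_parallel_group argument and the norm_allreduce_last_microbatch attribute come from this diff.

import torch.nn as nn
import torch.distributed as dist
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP


def wrap_with_tp_norm_allreduce(module: nn.Module, tp_group: dist.ProcessGroup) -> FSDP:
    # Flag the norm weights whose gradients must stay summed across the
    # tensor-parallel group; FlatParameter records their flat offsets
    # (see the flatten_params_wrapper.py change below).
    for m in module.modules():
        if isinstance(m, nn.LayerNorm):
            for p in m.parameters():
                p.norm_allreduce_last_microbatch = True
    return FSDP(module, tensor_parallel_group=tp_group)


def train_step(model: FSDP, microbatches, optimizer) -> None:
    # Gradient sync is skipped under no_sync(), and _post_backward_hook
    # returns early in that case, so the tensor-parallel all-reduce only
    # runs for the last microbatch.
    for mb in microbatches[:-1]:
        with model.no_sync():
            model(mb).sum().backward()
    model(microbatches[-1]).sum().backward()
    optimizer.step()
    optimizer.zero_grad()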

fairscale/nn/misc/flatten_params_wrapper.py

Lines changed: 7 additions & 0 deletions
@@ -70,6 +70,13 @@ def __new__(cls, params: Sequence[nn.Parameter], requires_grad: bool = True) ->
     def __init__(self, params: Sequence[nn.Parameter], requires_grad: bool = True):
         """Initialize the _param_numels and _param_shapes lists."""
         self._param_numels = [p.numel() for p in params]
+        self._param_require_tp_allreduce = []
+        for idx in range(len(params)):
+            p = params[idx]
+            if hasattr(p, "norm_allreduce_last_microbatch") and p.norm_allreduce_last_microbatch:
+                self._param_require_tp_allreduce.append(
+                    [sum(self._param_numels[0:idx]), sum(self._param_numels[0:idx+1])]
+                )
         assert self.numel() <= sum(
             self._param_numels
         ), f"Something wrong with __new__ method, {self.numel()} vs. {sum(self._param_numels)}"
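The recorded pairs are [start, end) numel offsets into the flattened parameter, which is what lets _post_backward_hook slice the flat gradient directly. A single-process sketch of that bookkeeping (variable names other than norm_allreduce_last_microbatch are illustrative):

import torch
import torch.nn as nn

# Two parameters that would be flattened together: a linear weight and a norm weight.
params = [nn.Parameter(torch.randn(4, 4)), nn.Parameter(torch.randn(4))]
params[1].norm_allreduce_last_microbatch = True  # flag the norm weight

param_numels = [p.numel() for p in params]
require_tp_allreduce = []
for idx, p in enumerate(params):
    if getattr(p, "norm_allreduce_last_microbatch", False):
        require_tp_allreduce.append([sum(param_numels[:idx]), sum(param_numels[:idx + 1])])

# The recorded [start, end) pair selects exactly the flagged weight's slice of the
# flattened tensor, so its gradient region can be all-reduced and copied back in place.
flat = torch.cat([p.detach().reshape(-1) for p in params])
start, end = require_tp_allreduce[0]
assert torch.equal(flat[start:end], params[1].detach().reshape(-1))
print(require_tp_allreduce)  # [[16, 20]]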
