@@ -369,6 +369,7 @@ def __init__(
         limit_all_gather_events: bool = False,
         limit_reduce_scatter_events: bool = False,
         should_validate_process_group: bool = True,
+        tensor_parallel_group: Optional[ProcessGroup] = None,
     ):
         try:
             import torch._C
@@ -380,6 +381,7 @@ def __init__(
         init_start = time.time()
         super().__init__()
         self.process_group = process_group or get_process_group_cached()
+        self.tensor_parallel_group = tensor_parallel_group
         # If ProcessGroupName.default is passed in, the reduce_scatter will use the same process group with
         # the rest of operations. The overlap feature in the backward propagation is disabled.
         if process_group_reduce_scatter == ProcessGroupName.default:
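For context, a minimal usage sketch of the new argument with this patch applied. The model factory `build_model` and the 2-way tensor-parallel layout below are illustrative assumptions, not part of the patch; only the `tensor_parallel_group` keyword comes from the change above.

# Hedged usage sketch: `build_model` and the 2-way TP layout are assumptions.
import torch.distributed as dist
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

dist.init_process_group(backend="nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()

# Assume 2-way tensor parallelism: consecutive rank pairs form one TP group.
tp_size = 2
groups = [dist.new_group(ranks=list(range(start, start + tp_size)))
          for start in range(0, world_size, tp_size)]
tp_group = groups[rank // tp_size]

model = FSDP(build_model(), tensor_parallel_group=tp_group)  # new kwarg from this patch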
@@ -1726,6 +1728,17 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
         if not self._require_backward_grad_sync:
             return

+        # All-reduce the gradient slices that are replicated across tensor-parallel ranks, if any.
+        if self.tensor_parallel_group is not None and self.tensor_parallel_group.size() > 1:
+            if self.fp32_reduce_scatter:
+                orig_grad_data = param.unsharded_main_grad.data
+            else:
+                orig_grad_data = param.grad.data
+            for idx_pair in param._param_require_tp_allreduce:
+                param_allreduce = orig_grad_data[idx_pair[0]:idx_pair[1]].contiguous()
+                torch.distributed.all_reduce(param_allreduce, group=self.tensor_parallel_group)
+                orig_grad_data[idx_pair[0]:idx_pair[1]].copy_(param_allreduce)
+
         # Wait for all work in the current stream to finish, then start the
         # reductions in post_backward stream.
         self._streams["post_backward"].wait_stream(torch.cuda.current_stream())
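As a standalone illustration of the hook logic added above, here is a minimal sketch of the slice-wise gradient all-reduce. The names `allreduce_tp_slices`, `grad`, and `tp_slices` are hypothetical; `tp_slices` stands in for the (start, end) index pairs stored in `param._param_require_tp_allreduce`.

import torch
import torch.distributed as dist

def allreduce_tp_slices(grad: torch.Tensor, tp_slices, tp_group) -> None:
    """Sum the given (start, end) slices of a flattened gradient across the TP group."""
    if tp_group is None or dist.get_world_size(group=tp_group) <= 1:
        return
    for start, end in tp_slices:
        shard = grad[start:end].contiguous()      # NCCL needs a contiguous buffer
        dist.all_reduce(shard, group=tp_group)    # defaults to SUM across TP ranks
        grad[start:end].copy_(shard)              # write the reduced values back in place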