-
Notifications
You must be signed in to change notification settings - Fork 87
/
Copy pathPipelineStage.py
754 lines (643 loc) · 25.7 KB
/
PipelineStage.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
# Copyright (c) Meta Platforms, Inc. and affiliates
import logging
import operator
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.distributed as dist
import torch.fx as fx
from torch._subclasses.fake_tensor import FakeTensor
from torch.nn.parallel import DistributedDataParallel
from pippy.backward import stage_backward
from pippy.debug import map_debug_info
from pippy.IR import Pipe
from pippy.microbatch import merge_chunks, split_args_kwargs_into_chunks
from pippy.utils import flatten_args, modify_graph_op_device
# Module-level logger shared by the helpers and stage classes in this file.
logger = logging.getLogger(__name__)
def _make_tensor_from_meta(
example_value: FakeTensor,
device: torch.device,
) -> torch.Tensor:
return torch.empty(
example_value.size(), dtype=example_value.dtype, device=device
)
class RecvInfo:
    """
    Describes one tensor that a stage receives from a peer stage.

    Attributes:
        input_name: label of the received value (for activations, the name of
            the producing graph node).
        source: stage index of the sender.
        buffer: pre-allocated tensor that the received data is written into.
    """

    def __init__(
        self,
        input_name: str,
        source: int,
        buffer: torch.Tensor,
    ):
        self.input_name, self.source, self.buffer = input_name, source, buffer

    def __repr__(self):
        buffer_shape = self.buffer.size()
        return f"RecvInfo(input={self.input_name}, source={self.source}, shape={buffer_shape})"
class StageArgPlaceholder:
    """
    Marker stored in a stage's receive info for an input that comes from a
    graph placeholder rather than from another stage.

    No receive buffer is allocated for such an input; it is filled from the
    locally split user inputs when the stage assembles its arguments.
    """
class PipelineStage(torch.nn.Module):
    """
    Runtime for one stage of a split PiPPy ``Pipe``.

    The stage runs the ``stage_index``-th child submodule of
    ``pipe.split_gm``. It pre-allocates receive buffers for its activations
    and, when the pipe has a loss and backward, for its output gradients.
    Tensors are exchanged with peer stages through ``dist.isend`` /
    ``dist.irecv`` on ``group``.

    Calling the module (``forward``) works as follows:

    - split the inputs into ``pipe.num_chunks`` microbatches;
    - run forward on all of them, then backward on all of them;
    - on the last stage, return the merged outputs (``None`` elsewhere).
    """

    def __init__(
        self,
        pipe: Pipe,
        stage_index: int,
        device: torch.device,
        group: dist.ProcessGroup = None,
    ):
        """
        Args:
            pipe: split pipeline whose ``split_gm`` holds one child per stage.
            stage_index: index of the child of ``pipe.split_gm`` run here.
            device: device for the submodule and the receive buffers.
            group: process group used for send/recv (``None`` means the
                default group).

        Raises:
            RuntimeError: if ``group`` has more ranks than there are stages.
            AssertionError: if this stage's forward node (or, in training
                mode, its ``stage_backward`` node) is not in the split graph.
        """
        super().__init__()
        self.pipe = pipe
        self.stage_index = stage_index
        self.nstages = pipe.num_stages
        self.chunks = pipe.num_chunks
        self.device = device
        self.group = group
        if dist.get_world_size(self.group) > self.nstages:
            raise RuntimeError(
                "Number of ranks is larger than number of stages, some ranks are unused"
            )

        # `group_rank` is rank in process group `group`.
        self.group_rank = dist.get_rank(group)

        # Run time states
        # map microbatch ID to (stage outputs, flattened inputs) for backward
        self.fwd_cache: Dict[int, Tuple[Any, List[torch.Tensor]]] = {}
        # Split input chunks
        self.args_split = None
        self.kwargs_split = None
        # Activation send requests of all chunk
        self.all_act_send_reqs: List[dist.Work] = []
        # Grad send requests of all chunk
        self.all_grad_send_reqs: List[dist.Work] = []
        # Caching chunk outputs for final output merge or reduction
        self.output_chunks: List[Any] = []

        # Find my submodule
        self.split_gm = self.pipe.split_gm
        named_children = list(self.split_gm.named_children())
        self.name, self.submod = named_children[stage_index]
        logger.info(
            f"[{self.group_rank}] "
            f"Creating PipelineStage:\n"
            f"{self.submod}"
        )

        # Find my forward node in graph
        found_node = False
        for node in self.split_gm.graph.nodes:
            if node.name == self.name:
                self.node = node
                found_node = True
                break
        if not found_node:
            raise AssertionError(f"Cannot find {self.name} in graph")

        # Find my backward node in graph. `stage_backward` nodes are counted
        # from the end of the graph: the k-th one seen in reverse order is
        # taken to belong to stage k.
        if self.pipe.has_loss_and_backwards:
            found_bwd = False
            seen_bwd = -1
            for node in reversed(self.split_gm.graph.nodes):
                if (node.op, node.target) == ("call_function", stage_backward):
                    seen_bwd += 1
                    if seen_bwd == self.stage_index:
                        found_bwd = True
                        self.bwd_node = node
                        break
            if not found_bwd:
                raise AssertionError(
                    f"Cannot find backward for {self.name} in graph"
                )

        # Create submod to stage index mapping
        self.submod_to_stage_index: Dict[str, int] = {}
        for i, (name, _) in enumerate(self.split_gm.named_children()):
            self.submod_to_stage_index.setdefault(name, i)

        # Create stage id to group rank mapping
        # In interleaved case, `group_rank` is stage index % group size.
        self.stage_index_to_group_rank: Dict[int, int] = {}
        pg_world_size = dist.get_world_size(group)
        for i in range(self.nstages):
            # We only support wrapped-around interleaving
            peer_rank = i % pg_world_size
            self.stage_index_to_group_rank.setdefault(i, peer_rank)

        # Prepare send/recv infrastructure
        self._prepare_send_recv_infra()
        # Cast submodule to device
        self._move_submod_to_device()
        # Move ops argument to device
        self._move_ops_to_device()

    def _move_submod_to_device(self):
        """Move the submodule to ``self.device`` unless it has meta/fake parameters."""
        # Note: we cannot move meta module to real devices because meta tensors
        # do not support to() method. One needs to do an in-place tensor swap in
        # that case.
        has_meta_param = any(
            isinstance(p, FakeTensor) or p.is_meta
            for p in self.submod.parameters()
        )
        if has_meta_param:
            logger.debug(f"[{self.group_rank}] Found meta parameters!")
        else:
            logger.debug(f"[{self.group_rank}] No meta parameters found!")
            self.submod.to(self.device)

    def _move_ops_to_device(self):
        """Rewrite device kwargs of ops in the submodule graph to ``self.device``."""
        # Today PT2 tracer does not treat `x.device` as a symbolic device;
        # instead, the device of tracing time got burned into the generated
        # code. Here we provide a workaround for users to manually modify the
        # "device" kwarg of operations. Such operation may include:
        # `torch.ones`, `torch.zeros`, `torch.rand`, etc.
        modify_graph_op_device(self.submod, self.device)

    def is_first(self):
        """Return True if this is stage 0."""
        return self.stage_index == 0

    def is_last(self):
        """Return True if this is the final stage of the pipeline."""
        return self.stage_index == self.nstages - 1

    def _prepare_send_recv_infra(self):
        """
        Create send/recv infrastructures for activations (during forward) and
        gradients (during backward)
        """
        # chunk : Tuple of arg buffers
        self.args_recv_info: Dict[int, Tuple] = {}
        # chunk : Dict of kwarg buffers
        self.kwargs_recv_info: Dict[int, Dict] = {}
        # Each chunk gets its own set of receive buffers.
        for chunk in range(self.chunks):
            (
                self.args_recv_info[chunk],
                self.kwargs_recv_info[chunk],
            ) = self._create_act_recv_buffers()

        # Send info during forward for each activation
        self.act_send_info = self._create_act_send_info()

        if self.pipe.has_loss_and_backwards:
            # chunk : List of output grad buffers
            # `grad_recv_info` is a mirror of `act_send_info`
            self.grad_recv_info: Dict = {}
            for chunk in range(self.chunks):
                self.grad_recv_info[chunk] = self._create_grad_recv_info(
                    self.act_send_info
                )

            # Send info for input grads during backward
            # List of destinations corresponding to input grads
            # Can be None if an input has no grad
            # `grad_send_info` is a mirror of `args_recv_info` + `kwargs_recv_info`
            self.grad_send_info = self._create_grad_send_info(
                self.args_recv_info[0],
                self.kwargs_recv_info[0],
            )

    def get_stage_index_of_submod(
        self,
        submod_name: str,
    ):
        """
        Return the stage index of the submodule named ``submod_name``.

        Raises:
            AssertionError: if the name is not a child of ``split_gm``.
        """
        if submod_name not in self.submod_to_stage_index:
            raise AssertionError(f"Stage id of {submod_name} not found")

        return self.submod_to_stage_index[submod_name]

    def _create_act_recv_buffers(
        self,
    ):
        """
        Build receive info for every input of this stage's forward node.

        Returns:
            ``(args_recv_info, kwargs_recv_info)``. These mirror
            ``self.node.args`` / ``self.node.kwargs``, with each graph-node
            input replaced as follows:

            - an input produced by another stage becomes a ``RecvInfo``;
            - an input coming from a graph placeholder becomes a
              ``StageArgPlaceholder``.
        """

        def create_recv_tensor(
            input_node,
            output_idx: Optional[int] = None,
        ):
            """
            Create a tensor for receiving the `output_idx`-th value from
            `input_node`
            """
            if input_node.op == "placeholder":
                # Do not create buffer for placeholder
                return StageArgPlaceholder()

            # In case the input is a `getitem` node, we recursively find the
            # real source e.g. getitem1 = submod0[1]
            # Here `submod0` is args[0], 1 is args[1]
            if input_node.target is operator.getitem:
                if "example_value" in input_node.meta:
                    real_input_node = input_node.args[0]
                    out_idx = input_node.args[1]
                    return create_recv_tensor(real_input_node, out_idx)
                else:
                    raise NotImplementedError(
                        f"getitem gets a non-Tensor value, this is not yet supported. "
                        f"Node: {input_node.format_node()}"
                    )

            if output_idx is not None:
                # If a node has multiple output values, "example_value" is a list
                # of tensor meta
                example_value = input_node.meta["example_value"][output_idx]
            else:
                example_value = input_node.meta["example_value"]

            logger.info(
                f"[{self.group_rank}] "
                f"Creating recv buffer for input '{input_node.name}' "
                f"value index {output_idx}: {example_value.size()}"
            )

            # `source` of a RecvInfo is a stage index, mapped to a group rank
            # only when the receive is posted.
            src_rank = self.get_stage_index_of_submod(input_node.name)
            buffer = _make_tensor_from_meta(example_value, self.device)
            # Enable gradient in training mode
            if self.pipe.has_loss_and_backwards:
                buffer.requires_grad_(True)
            return RecvInfo(
                input_node.name,
                src_rank,
                buffer,
            )

        # `args` is a Tuple, hence we will have:
        # Tuple[RecvInfo]
        args_recv_info = fx.node.map_arg(self.node.args, create_recv_tensor)

        # `kwargs` is a Dict, hence we will have:
        # Dict[keyword, RecvInfo]
        kwargs_recv_info = fx.node.map_arg(self.node.kwargs, create_recv_tensor)

        logger.info(
            f"[{self.group_rank}] " f"Activation recv info: {args_recv_info}"
        )
        return args_recv_info, kwargs_recv_info

    def find_dst_rank(
        self,
        user: fx.Node,
    ) -> Optional[int]:
        """
        Find the destination rank of a `user` node.
        If the `user` is not a submod, `None` may be returned.
        """
        if user.op == "call_module":
            # User is a stage (`call_module`)
            return self.get_stage_index_of_submod(user.name)
        else:
            # - If user.op == "output":
            #   No need to send back to rank 0
            # - If user.target is stage_backward:
            #   No need to send assuming submod output is stored locally or
            #   should be re-calculated in case of activation checkpointing
            return None

    def _create_act_send_info(self):
        """
        Map each output index of this stage to the list of destination stages.

        Returns:
            ``Dict[output_index, List[stage_index]]``. An empty list means the
            output is not consumed by any other stage.
        """
        # Output index: List of receiver ranks
        act_send_info: Dict[int, List] = {}
        out_idx = 0

        for user in self.node.users:
            if user.target is operator.getitem:
                # Recursively find the real destination
                gi_dsts = act_send_info.setdefault(out_idx, [])
                for gi_user in user.users:
                    dst_rank = self.find_dst_rank(gi_user)
                    if dst_rank is not None:
                        gi_dsts.append(dst_rank)
                # Next `getitem` will point to the next output index.
                # NOTE(review): this assumes getitem users are visited in
                # output-index order, one per index — confirm for new graphs.
                out_idx += 1
            else:
                # In case of single output value, `out_idx` will not increase
                dsts = act_send_info.setdefault(out_idx, [])
                dst_rank = self.find_dst_rank(user)
                if dst_rank is not None:
                    dsts.append(dst_rank)

        logger.info(f"[{self.group_rank}] " f"Send info: {act_send_info}")
        return act_send_info

    def _create_grad_recv_info(
        self,
        act_send_info: Dict,
    ) -> Dict[int, RecvInfo]:
        """
        Build receive info for the gradients of this stage's outputs.

        Each output that was sent to another stage gets one gradient buffer,
        received back from that stage.

        Returns:
            ``Dict[output_index, RecvInfo]``; outputs with no consumer are
            omitted.

        Raises:
            AssertionError: if an output is consumed by more than one stage
                (gradient accumulation is not supported yet).
        """
        # Dict[output_index, RecvInfo]
        grad_recv_info: Dict = {}
        my_example_value = self.node.meta["example_value"]
        # A node with multiple output values records a list of tensor meta in
        # "example_value"; a single-output node records the value itself.
        # Deciding by the value (not by how many outputs have consumers) keeps
        # multi-output stages with a single consumed output correct.
        is_multi_output = isinstance(my_example_value, (list, tuple))

        for out_idx, dst_list in act_send_info.items():
            if not dst_list:
                # No actual receiver for activation so no grad coming back
                continue

            if is_multi_output:
                example_value = my_example_value[out_idx]
            else:
                example_value = my_example_value

            # TODO: otherwise needs grad accumulation
            if len(dst_list) != 1:
                raise AssertionError(
                    f"[{self.group_rank}] Output {out_idx} of {self.name} is "
                    f"consumed by stages {dst_list}; gradient accumulation "
                    f"from multiple consumers is not supported"
                )
            grad_src = dst_list[0]
            grad_recv_info[out_idx] = RecvInfo(
                f"{grad_src}",
                grad_src,
                _make_tensor_from_meta(example_value, self.device),
            )

        logger.info(f"[{self.group_rank}] " f"Grad recv info: {grad_recv_info}")
        return grad_recv_info

    def _create_grad_send_info(
        self,
        args_recv_info: Tuple,
        kwargs_recv_info: Dict,
    ) -> List[Optional[int]]:
        """
        Flatten the receive info of one chunk into per-input gradient targets.

        Returns:
            One entry per flattened input, args first, then kwargs. Each entry
            is the source stage index for a ``RecvInfo`` input and ``None``
            otherwise.
        """
        grad_send_info: List[Optional[int]] = []

        def map_recv_to_send(a):
            if isinstance(a, RecvInfo):
                grad_send_info.append(a.source)
                return a.source
            else:
                grad_send_info.append(None)
                return None

        fx.node.map_aggregate(args_recv_info, map_recv_to_send)

        fx.node.map_aggregate(kwargs_recv_info, map_recv_to_send)

        logger.info(f"[{self.group_rank}] " f"Grad send info: {grad_send_info}")
        return grad_send_info

    def _recv_tensor(self, info, recv_reqs):
        """
        Post an asynchronous receive into ``info.buffer``.

        The sender is the rank hosting stage ``info.source``. The work handle
        is appended to ``recv_reqs``; the returned buffer is only valid after
        that work has been waited on.
        """
        logger.debug(
            f"[{self.group_rank}] "
            f"Receiving tensor '{info.input_name}' from Rank {info.source}: "
            f"{info.buffer.size()}"
        )
        # Use async to parallelize recv of tensors
        peer_rank = self.stage_index_to_group_rank[info.source]
        work = dist.irecv(
            info.buffer,
            peer_rank
            if self.group is None
            else dist.get_global_rank(self.group, peer_rank),
            group=self.group,
        )
        recv_reqs.append(work)
        return info.buffer

    def recv_tensor_fn(
        self,
        reqs,
    ):
        """Return a one-argument receiver that appends its work handles to ``reqs``."""
        return lambda info: self._recv_tensor(info, reqs)

    def split_inputs(self, args, kwargs):
        """
        Split user inputs into ``self.chunks`` microbatches.

        The result is stored in ``self.args_split`` / ``self.kwargs_split``.
        Both stay ``None`` when no inputs are given.
        """
        self.args_split = None
        self.kwargs_split = None
        if args or kwargs:
            self.args_split, self.kwargs_split = split_args_kwargs_into_chunks(
                args,
                kwargs,
                self.chunks,
                self.pipe.args_chunk_spec,
                self.pipe.kwargs_chunk_spec,
            )

    def _recv_and_fill_inputs(
        self,
        chunk: int,
    ):
        """
        Assemble the positional and keyword inputs of microbatch ``chunk``.

        Each ``RecvInfo`` entry is filled by an asynchronous receive from the
        producing stage. Every other entry (e.g. a ``StageArgPlaceholder``)
        is taken, in order, from this stage's locally split inputs. The method
        blocks until all receives for the chunk have completed.

        Returns:
            ``(composite_args, composite_kwargs)`` for calling the submodule.
        """
        # Receive requests of a chunk
        recv_reqs: List[dist.Work] = []

        act_recv = self.recv_tensor_fn(recv_reqs)

        # Local inputs of this chunk. A stage that gets all its inputs from
        # peers is called without arguments, so `args_split` / `kwargs_split`
        # may be None. Start from empty containers so the inputs below are
        # always assembled (and the receives always posted).
        chunk_args_list: List[Any] = []
        if self.args_split:
            chunk_args_list = list(self.args_split[chunk])

        chunk_kwargs: Dict[str, Any] = {}
        if self.kwargs_split:
            # Copy so that consuming entries does not mutate the stored split.
            chunk_kwargs = dict(self.kwargs_split[chunk])

        def recv_args(info):
            if isinstance(info, RecvInfo):
                return act_recv(info)
            else:
                return chunk_args_list.pop(0)

        def recv_kwargs(info):
            if isinstance(info, RecvInfo):
                return act_recv(info)
            else:
                # Local kwargs are consumed in insertion order, which is
                # assumed to match the order of `kwargs_recv_info`.
                k = next(iter(chunk_kwargs))
                return chunk_kwargs.pop(k)

        composite_args = fx.node.map_aggregate(
            self.args_recv_info[chunk],
            recv_args,
        )

        composite_kwargs = fx.node.map_aggregate(
            self.kwargs_recv_info[chunk],
            recv_kwargs,
        )

        # Wait for all recvs to finish
        for work in recv_reqs:
            work.wait()

        return composite_args, composite_kwargs

    def _send_activations(
        self,
        output_tuple,
    ) -> List[dist.Work]:
        """
        Post asynchronous sends of each output to its consumer stages.

        Returns:
            The send work handles; the caller must wait on them.
        """
        # Send requests of a chunk
        send_reqs: List[dist.Work] = []

        for idx, out in enumerate(output_tuple):
            dst_stages = self.act_send_info[idx]
            for dst in dst_stages:
                if dst is None:
                    continue
                logger.debug(
                    f"[{self.group_rank}] "
                    f"Sending tensor to Rank {dst}: {out.size()}"
                )
                peer_rank = self.stage_index_to_group_rank[dst]
                work = dist.isend(
                    # HACK: we convert DTensor to regular tensor here for it to
                    # work with send ops. DTensor may show up in PP + TP cases.
                    out.to_local()
                    if isinstance(out, torch.distributed._tensor.DTensor)
                    else out,
                    peer_rank
                    if self.group is None
                    else dist.get_global_rank(self.group, peer_rank),  # TODO
                    group=self.group,
                )
                send_reqs.append(work)

        return send_reqs

    def _recv_grads(
        self,
        bwd_chunk,
    ):
        """
        Receive the output gradients of microbatch ``bwd_chunk``.

        Blocks until all gradients have arrived.

        Returns:
            A mapping shaped like ``self.grad_recv_info[bwd_chunk]`` holding
            the received gradient tensors.
        """
        # Receive requests of a chunk
        grad_recv_reqs: List[dist.Work] = []

        recv_grad = self.recv_tensor_fn(grad_recv_reqs)

        # Receive gradients
        grads = fx.node.map_aggregate(
            self.grad_recv_info[bwd_chunk],
            recv_grad,
        )
        # Wait for all recvs to finish
        for work in grad_recv_reqs:
            work.wait()

        logger.debug(
            f"[{self.group_rank}] "
            f"Received output grads of chunk {bwd_chunk}: {map_debug_info(grads)}"
        )
        return grads

    def _send_grads(
        self,
        grads_input,
    ) -> List[dist.Work]:
        """
        Post asynchronous sends of input gradients to the producing stages.

        ``grads_input`` is zipped with ``self.grad_send_info``. A tensor
        gradient is sent to its producing stage. Otherwise both the gradient
        and its destination must be ``None``.

        Returns:
            The send work handles; the caller must wait on them.

        Raises:
            AssertionError: if a gradient and its destination disagree (one
                present, the other missing).
        """
        # Send requests of a chunk
        grad_send_reqs: List[dist.Work] = []

        for grad, grad_recv_stage in zip(grads_input, self.grad_send_info):
            if isinstance(grad, torch.Tensor) and grad_recv_stage is not None:
                logger.debug(
                    f"[{self.group_rank}] "
                    f"Sending gradient to Rank {grad_recv_stage}: {grad.size()}"
                )
                peer_rank = self.stage_index_to_group_rank[grad_recv_stage]
                work = dist.isend(
                    grad,
                    peer_rank
                    if self.group is None
                    else dist.get_global_rank(self.group, peer_rank),  # TODO
                    group=self.group,
                )
                grad_send_reqs.append(work)
            elif not (grad is None and grad_recv_stage is None):
                # Explicit check (not `assert`) so it still runs under `-O`.
                raise AssertionError(
                    f"[{self.group_rank}] Mismatched input gradient: got "
                    f"{type(grad).__name__} with destination stage "
                    f"{grad_recv_stage}"
                )

        return grad_send_reqs

    def forward_maybe_with_nosync(self, *args, **kwargs):
        """Run the submodule forward, under ``no_sync`` if it is DDP-wrapped."""
        # If submod is wrapped with DDP, we use the `no_sync` context manager to
        # avoid gradient all-reduce per microbatch
        if isinstance(self.submod, DistributedDataParallel):
            with self.submod.no_sync():  # type: ignore[operator]
                out_val = self.submod(*args, **kwargs)
        else:
            out_val = self.submod(*args, **kwargs)
        return out_val

    def backward_maybe_with_nosync(self, bwd_kwargs: Dict, is_last_chunk: bool):
        """
        Run ``stage_backward(**bwd_kwargs)`` and return the input gradients.

        For a DDP-wrapped submodule, only the last chunk arms the reducer
        (``prepare_for_backward``). Earlier chunks run under ``no_sync``.
        """
        if isinstance(self.submod, DistributedDataParallel):
            if is_last_chunk:
                # HACK: reaching into DDP implementation details here. Is there a better way?
                self.submod.reducer.prepare_for_backward(  # type: ignore[union-attr, operator]
                    list(
                        torch.nn.parallel.distributed._find_tensors(  # type: ignore[attr-defined]
                            bwd_kwargs["stage_output"]
                        )
                    )
                )
                grads_input, _ = stage_backward(**bwd_kwargs)
            else:
                with self.submod.no_sync():  # type: ignore[operator]
                    grads_input, _ = stage_backward(**bwd_kwargs)
        else:
            # Non-DDP submodule, regular backward
            grads_input, _ = stage_backward(**bwd_kwargs)
        return grads_input

    def forward_one_chunk(
        self,
        chunk: int,
    ):
        """
        Run forward for microbatch ``chunk``.

        Steps:

        - receive or fill the chunk's inputs and run the submodule;
        - keep the output for the final merge;
        - post sends of the output to consumer stages;
        - cache the output and flattened inputs for backward.

        Raises:
            RuntimeError: wrapping any exception raised by the submodule.
        """
        composite_args, composite_kwargs = self._recv_and_fill_inputs(chunk)

        # Compute forward
        try:
            output = self.forward_maybe_with_nosync(
                *composite_args, **composite_kwargs
            )

        except Exception as e:
            exc_msg = (
                f"\nRank {self.group_rank} failed to run forward stage: {self.name}\n"
                f"args: {map_debug_info(composite_args)}\n"
                f"kwargs: {map_debug_info(composite_kwargs)}\n"
            )
            raise RuntimeError(exc_msg) from e

        if type(output) is list:
            # HACK: this is a hacky workaround for the fact that export creates
            # output in list format
            output = tuple(output)
        logger.debug(map_debug_info(output))
        # Unify output form to tuple for easy correspondence with
        # `act_send_info`
        output_tuple = output if type(output) is tuple else (output,)
        # Prepare for final output merge or reduction
        self.output_chunks.append(output)

        # Send activations
        send_reqs = self._send_activations(output_tuple)
        self.all_act_send_reqs += send_reqs

        # Save activations and inputs for backward
        flat_args = flatten_args(composite_args)
        flat_kwargs = flatten_args(composite_kwargs)
        flatten_input_tensors = flat_args + flat_kwargs
        self.fwd_cache[chunk] = (
            output_tuple,  # stage_output
            flatten_input_tensors,  # input_values
        )

    def backward_one_chunk(
        self,
        bwd_chunk: int,
    ):
        """
        Run backward for microbatch ``bwd_chunk``.

        Does nothing (returns None) when the pipe has no loss/backward.
        Otherwise it:

        - receives the output gradients;
        - runs ``stage_backward`` on the cached forward values (popping them
          from the cache);
        - posts sends of the input gradients.
        """
        if not self.pipe.has_loss_and_backwards:
            return None

        grads = self._recv_grads(bwd_chunk)

        # Pack args for `stage_backward`
        bwd_kwargs = dict(self.bwd_node.kwargs)
        (
            bwd_kwargs["stage_output"],
            bwd_kwargs["input_values"],
        ) = self.fwd_cache.pop(bwd_chunk)
        # Fill actual gradients received for outputs
        # If nothing received, as in the case of last stage, then we
        # would use the default `output_grads` prepared in the IR phase,
        # i.e. from `bwd_node.kwargs`. For example, it may look like
        # this if there are two outputs: ('None', 'None')
        if len(grads) > 0:
            bwd_kwargs["output_grads"] = grads

        # `stage_backward` node does not have `args`, only `kwargs`
        grads_input = self.backward_maybe_with_nosync(
            bwd_kwargs,
            bwd_chunk == self.chunks - 1,
        )

        grad_send_reqs = self._send_grads(grads_input)
        self.all_grad_send_reqs += grad_send_reqs

    def clear_runtime_states(self):
        """Reset per-iteration caches, pending send handles and output chunks."""
        # map microbatch ID to list of forward tensor args
        self.fwd_cache.clear()
        # Activation send requests of all chunk
        self.all_act_send_reqs.clear()
        # Grad send requests of all chunk
        self.all_grad_send_reqs.clear()
        # Caching chunk outputs for final output merge or reduction
        self.output_chunks.clear()

    def merge_output_chunks(self):
        """Merge per-chunk outputs according to ``pipe.output_chunk_spec``."""
        return merge_chunks(
            self.output_chunks,
            self.pipe.output_chunk_spec,
        )

    def forward(self, *args, **kwargs):
        """
        Run one iteration with a fill-drain schedule.

        The schedule is: forward on every chunk, then backward on every
        chunk, then wait for all pending sends.

        Returns:
            The merged outputs on the last stage, ``None`` on other stages.
        """
        # Clean per iteration
        self.clear_runtime_states()

        # Split inputs into chunks
        self.split_inputs(args, kwargs)

        # Forward pass of all chunks
        for chunk in range(self.chunks):
            self.forward_one_chunk(chunk)
            logger.debug(f"[{self.group_rank}] Forwarded chunk {chunk}")

        # Backward starts here

        for bwd_chunk in range(self.chunks):
            self.backward_one_chunk(bwd_chunk)
            logger.debug(f"[{self.group_rank}] Backwarded chunk {bwd_chunk}")

        # Wait for all sends to finish
        # TODO: okay to delay the sync till completion of all chunks?
        for work in self.all_act_send_reqs:
            work.wait()

        # Wait for all sends to finish
        # TODO: okay to delay the sync till completion of all chunks?
        for work in self.all_grad_send_reqs:
            work.wait()

        # Last rank return merged results per original format
        if self.is_last():
            return self.merge_output_chunks()
        else:
            return None
class PipelineStage1F1B(PipelineStage):
    """
    ``PipelineStage`` driven by a one-forward-one-backward (1F1B) schedule.

    Each call runs in three phases:

    1. warm up with forward on up to ``nstages`` microbatches;
    2. alternate one backward with one forward;
    3. drain the remaining backwards.
    """

    def __init__(
        self,
        pipe: Pipe,
        rank: int,
        device: torch.device,
        group: dist.ProcessGroup = None,
    ):
        """``rank`` is passed to ``PipelineStage`` as the stage index."""
        super().__init__(
            pipe,
            rank,
            device,
            group=group,
        )

    def forward(self, *args, **kwargs):
        """
        Run one iteration with the 1F1B schedule.

        Returns:
            The merged outputs on the last stage, ``None`` on other stages.
        """
        # Clean per iteration
        self.clear_runtime_states()

        # Split inputs into chunks
        self.split_inputs(args, kwargs)

        # Warm-up/cool-down depth is the pipeline depth, but never more than
        # the number of microbatches. Otherwise, with chunks < nstages, the
        # warm-up would run forward on chunk ids that do not exist and the
        # cool-down would run backward on negative chunk ids.
        warmup_chunks = cooldown_chunks = min(self.nstages, self.chunks)

        # Warm-up phase: forward number of chunks equal to pipeline depth.
        for chunk in range(warmup_chunks):
            self.forward_one_chunk(chunk)

        # 1F1B phase
        for bwd_chunk in range(0, self.chunks - cooldown_chunks):
            # Schedule backward for one warmed up chunk
            self.backward_one_chunk(bwd_chunk)

            # Schedule forward for one new chunk
            fwd_chunk = bwd_chunk + warmup_chunks
            self.forward_one_chunk(fwd_chunk)

        # Cool-down phase: backward for the rest of the chunks
        for bwd_chunk in range(self.chunks - cooldown_chunks, self.chunks):
            self.backward_one_chunk(bwd_chunk)

        # Wait for all sends to finish
        # TODO: okay to delay the sync till completion of all chunks?
        for work in self.all_act_send_reqs:
            work.wait()

        for work in self.all_grad_send_reqs:
            work.wait()

        # Last rank return merged results per original format
        if self.is_last():
            return self.merge_output_chunks()
        else:
            return None