from typing import Any

import torch
+ from torch.autograd.graph import saved_tensors_hooks

- from .initialize import get_model_parallel_group
+ from .initialize import get_model_parallel_group, get_model_parallel_world_size, get_model_parallel_rank
from .utils import split_tensor_along_last_dim

@@ -154,3 +155,44 @@ def scatter_to_model_parallel_region(input_: torch.Tensor) -> torch.Tensor:

def gather_from_model_parallel_region(input_: torch.Tensor) -> torch.Tensor:
    return _GatherFromModelParallelRegion.apply(input_)
+
+
+ def _pack_over_mp(tensor):
+     mp_world_size = get_model_parallel_world_size()
+     if mp_world_size == 1:
+         return tensor  # no-op for mp=1
+     full_tensor_shape = list(tensor.shape)
+     # Each rank keeps only its own 1/mp_world_size slice of the flattened tensor;
+     # assumes tensor is contiguous and tensor.numel() is divisible by mp_world_size
+     # so every shard has the same size (required by the all-gather on unpack).
+     shard = tensor.view(-1).chunk(mp_world_size, dim=0)[get_model_parallel_rank()]
+     shard = shard.detach().clone().contiguous()  # clone to explicitly release memory of the full tensor
+     del tensor
+     return shard, full_tensor_shape
+
+
+ def _unpack_over_mp(sharded_tensor):
+     mp_world_size = get_model_parallel_world_size()
+     if mp_world_size == 1:
+         return sharded_tensor  # no-op for mp=1: _pack_over_mp returned the tensor unchanged
+     sharded_tensor, full_tensor_shape = sharded_tensor
+     full_tensor = torch.empty(
+         *full_tensor_shape,
+         dtype=sharded_tensor.dtype,
+         device=sharded_tensor.device)
+
+     # All-gather the shards from every model parallel rank to rebuild the full tensor.
+     torch.distributed.all_gather_into_tensor(
+         full_tensor.view(-1), sharded_tensor, group=get_model_parallel_group()
+     )
+
+     return full_tensor
+
+
+ class shard_over_mp_group(saved_tensors_hooks):
+     """Context manager for activation sharding.
+
+     This context manager shards tensors saved by autograd over the
+     model parallel group in the forward pass and unshards them in the
+     backward pass. Useful for removing redundancy in long-lived
+     activation tensors.
+     """
+
+     def __init__(self):
+         super().__init__(_pack_over_mp, _unpack_over_mp)
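For context, a minimal usage sketch of the new context manager (not part of the diff). It assumes torch.distributed and the package's model parallel state are already initialized so that get_model_parallel_group() is valid, that the patched file is the package's model_parallel/mappings.py (the import path is an assumption), and that the saved activations' element counts are divisible by the model parallel world size; the layer and input below are purely illustrative.

import torch
import torch.nn as nn

from model_parallel.mappings import shard_over_mp_group  # assumed import path

# Prerequisites (not shown): torch.distributed.init_process_group(...) and the
# package's model parallel initialization, so every rank in the group runs this.

layer = nn.Linear(1024, 1024).cuda()
x = torch.randn(8, 1024, device="cuda", requires_grad=True)

with shard_over_mp_group():
    # Tensors saved for backward inside this block are packed by _pack_over_mp:
    # each model parallel rank keeps only its 1/world_size slice of the flat tensor.
    y = layer(x)

loss = y.float().sum()
loss.backward()  # _unpack_over_mp all-gathers the shards back into full tensors here

The saving applies only to activations stashed by autograd inside the block; parameters and gradients are unaffected, and with a model parallel world size of 1 the hooks are no-ops.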