
Commit 71d718b

Add a context manager for activation sharding.
1 parent 164cc0f commit 71d718b

File tree: 1 file changed (+43, -1 lines)


fairscale/nn/model_parallel/mappings.py

Lines changed: 43 additions & 1 deletion
@@ -22,8 +22,9 @@
 from typing import Any
 
 import torch
+from torch.autograd.graph import saved_tensors_hooks
 
-from .initialize import get_model_parallel_group
+from .initialize import get_model_parallel_group, get_model_parallel_world_size, get_model_parallel_rank
 from .utils import split_tensor_along_last_dim
 
 

@@ -154,3 +155,44 @@ def scatter_to_model_parallel_region(input_: torch.Tensor) -> torch.Tensor:
 
 def gather_from_model_parallel_region(input_: torch.Tensor) -> torch.Tensor:
     return _GatherFromModelParallelRegion.apply(input_)
+
+
+def _pack_over_mp(tensor):
+    mp_world_size = get_model_parallel_world_size()
+    if mp_world_size == 1:
+        return tensor  # no-op for mp=1
+    full_tensor_shape = list(tensor.shape)
+    shard = tensor.view(-1).chunk(mp_world_size, dim=0)[get_model_parallel_rank()]
+    shard = shard.detach().clone().contiguous()  # clone to explicitly release memory of the full tensor
+    del tensor
+    return shard, full_tensor_shape
+
+
+def _unpack_over_mp(sharded_tensor):
+    mp_world_size = get_model_parallel_world_size()
+    if mp_world_size == 1:
+        return sharded_tensor  # no-op for mp=1; pack saved the tensor unchanged, not a (shard, shape) pair
+    sharded_tensor, full_tensor_shape = sharded_tensor
+    full_tensor = torch.empty(
+        *full_tensor_shape,
+        dtype=sharded_tensor.dtype,
+        device=sharded_tensor.device)
+
+    torch.distributed.all_gather_into_tensor(
+        full_tensor.view(-1), sharded_tensor, group=get_model_parallel_group()
+    )
+
+    return full_tensor
+
+
+class shard_over_mp_group(saved_tensors_hooks):
+    """Context manager for activation sharding.
+
+    This context manager shards tensors saved by autograd over the
+    model parallel group in the forward pass and unshards them
+    in the backward pass. Useful for removing redundancy in the
+    long-lived activation tensors.
+    """
+
+    def __init__(self):
+        super().__init__(_pack_over_mp, _unpack_over_mp)
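
Usage sketch (not part of the diff): assuming torch.distributed and fairscale model parallelism are already initialized, the new context manager wraps the forward pass so that the tensors autograd saves for backward are kept sharded across the model parallel group. The helper name run_step and the model/batch arguments below are illustrative only; shard_over_mp_group is imported from the module this commit modifies.

    from fairscale.nn.model_parallel.mappings import shard_over_mp_group

    def run_step(model, batch):
        # Tensors saved for backward inside this block are packed into
        # per-rank shards at save time and all-gathered back to their
        # full shape when the backward pass needs them.
        with shard_over_mp_group():
            loss = model(batch).sum()
        loss.backward()
        return loss

Since every model parallel rank would otherwise hold an identical copy of these saved activations, sharding them trades that redundancy for one all-gather per saved tensor during the backward pass, which is the saving the docstring refers to.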
