Commit 0b77de4

Author: Naman Goyal (committed)
added option to do backward AG over smaller set of gpus instead of full DDP

1 parent: ba38cf3

1 file changed: fairscale/nn/data_parallel/fully_sharded_data_parallel.py (180 additions, 2 deletions)

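As context for the diff below: the new `zero2_process_group` argument lets the pre-backward all-gather run over a smaller process group (for example, the ranks of one node) instead of the full data-parallel group. A minimal usage sketch, not part of this commit; the node-local group construction and the `gpus_per_node` value are illustrative assumptions:

import torch
import torch.distributed as dist
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

def wrap_with_zero2_group(module: torch.nn.Module, gpus_per_node: int = 8) -> FSDP:
    # Assumes torch.distributed is already initialized (e.g., NCCL backend).
    world_size = dist.get_world_size()
    rank = dist.get_rank()

    # Build one subgroup per node. Every rank must call new_group() for every
    # subgroup; each rank keeps only the group it belongs to.
    zero2_pg = None
    for start in range(0, world_size, gpus_per_node):
        ranks = list(range(start, min(start + gpus_per_node, world_size)))
        group = dist.new_group(ranks=ranks)
        if rank in ranks:
            zero2_pg = group

    # After the forward pass, parameters are re-sharded across zero2_pg only,
    # so the backward all-gather stays inside the smaller group. Note the
    # commit asserts this option is not combined with move_params_to_cpu.
    return FSDP(module, reshard_after_forward=True, zero2_process_group=zero2_pg)
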
@@ -332,6 +332,7 @@ def __init__(
         offload_config: Optional[OffloadConfig] = None,
         state_dict_on_rank_0_only: bool = False,
         gradient_predivide_factor: Optional[float] = None,
+        zero2_process_group: Optional[ProcessGroup] = None,
     ):
         try:
             import torch._C
@@ -380,6 +381,9 @@ def __init__(
                 "parameter uses all the available ranks for the optimal performance."
             )
         self.reshard_after_forward = self._orig_reshard_after_forward = reshard_after_forward
+
+        self.zero2_process_group = zero2_process_group
+
         self.disable_reshard_on_root = disable_reshard_on_root
         self.mixed_precision = mixed_precision
         self.fp32_reduce_scatter = fp32_reduce_scatter
@@ -518,6 +522,9 @@ def __init__(
             if isinstance(m, FullyShardedDataParallel):
                 m._free_ssd_offload()

+        if self.zero2_process_group is not None:
+            assert not self.move_params_to_cpu
+
     def _get_gradient_predivide_factor(self, world_size: int) -> float:
         factor: int = 1
         while world_size % factor == 0 and world_size / factor > factor:
@@ -1419,7 +1426,10 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
         outputs = self.module(*args, **kwargs)

         if self.reshard_after_forward:
-            self._free_full_params()
+            if self.zero2_process_group is not None:
+                self._zero2_shard_to_smaller_group()
+            else:
+                self._free_full_params()
         if self.mixed_precision or self.move_params_to_cpu:
             self._free_fp16_param_shard()

@@ -1499,7 +1509,10 @@ def _pre_backward_hook(*unused: Any) -> None:
            # idempotent. So in case they are called unnecessarily, they don't incur much
            # overhead.
            if self.reshard_after_forward:
-                self._rebuild_full_params()
+                if self.zero2_process_group is not None:
+                    self._zero2_rebuild_full_params()
+                else:
+                    self._rebuild_full_params()
            if (
                self.reshard_after_forward
                and self._fsdp_forward_ordering is not None
@@ -2006,6 +2019,126 @@ def update_p_data(custom_output_tensor: Optional[torch.Tensor] = None) -> None:
         torch.cuda.current_stream().wait_stream(self._streams["all_gather"])
         return output_tensors

+
+    @torch.no_grad()
+    def _zero2_rebuild_full_params(self, force_full_precision: bool = False, wait_for_all_gather = True) -> Optional[List[Tuple[torch.Tensor, bool]]]:
+        """
+        Gather all shards of params.
+
+        Note, this is idempotent if full params are already gathered. Callers
+        assume the idempotency. So please keep it that way.
+
+        Args:
+            force_full_precision (bool, Optional): by default params will be gathered
+                in ``compute_dtype`` (e.g., FP16), unless *force_full_precision* is
+                ``True``, in which case they will be gathered in full precision
+                (e.g., FP32), possibly in fresh storage. The parameter that's being
+                rebuilt will end up in full precision as well.
+
+        Returns:
+            A list of tuples, where the first element is the full-sized param
+            and the second element is a bool indicating if it's safe for the
+            caller to free the full-sized param. This will be ``None`` if
+            ``force_full_precision=False`` and the full params are already gathered.
+        """
+        output_tensors: List[Tuple[torch.Tensor, bool]] = []
+
+        def update_p_data(custom_output_tensor: Optional[torch.Tensor] = None) -> None:
+            """
+            Helper function to update p.data pointer.
+
+            Args:
+                custom_output_tensor (torch.Tensor, Optional): if not None, this
+                    tensor contains the data we just gathered.
+            """
+            if custom_output_tensor is not None:
+                assert p._is_sharded
+                p.data = custom_output_tensor
+                output_tensors.append((p.data, True))
+            elif not p._is_sharded:
+                if (self.mixed_precision or self.move_params_to_cpu) and not force_full_precision:
+                    assert p._fp16_shard is not None
+                    p.data = p._fp16_shard
+                    output_tensors.append((p.data, True))
+                else:
+                    # Here p.data == p._fp32_shard, so it's not safe to free.
+                    output_tensors.append((p.data, False))
+            else:
+                p.data = p._full_param_padded
+                output_tensors.append((p.data, True))
+            # Trim any padding and reshape to match original size.
+            p.data = p.data[: p._orig_size.numel()].view(p._orig_size)

+        if self._has_shared_params:
+            # self.has_full_params flag can be out of sync if a shared param is
+            # sharded by another FSDP instance. An example is that in eval case
+            # with reshard_after_forward=False but the sharing instance has
+            # reshard_after_forward=True. Then, on the second forward, the
+            # other instance can shard the shared param, but this instance
+            # can mistakenly think the full param is already gathered from the
+            # has_full_params flag.
+            #
+            # Therefore, we update the flag accordingly here.
+            self.has_full_params = not any(p._full_param_padded.storage().size() == 0 for p in self.params)
+
+        # Early exit if we already have full params and don't need full precision.
+        if self.has_full_params and not force_full_precision:
+            if wait_for_all_gather:
+                torch.cuda.current_stream().wait_stream(self._streams["all_gather"])
+            for p in self.params:
+                update_p_data()
+            return output_tensors
+
+        self.has_full_params = True
+
+        with torch.cuda.stream(self._streams["all_gather"]):
+
+            for p in self.params:
+                if not p._is_sharded:  # e.g., when world_size == 1
+                    update_p_data()
+                else:
+                    # Skip if already built. Only shared param can be rebuilt multiple times.
+                    # A corner case is p._orig_size = (1,), which means the shape equality is
+                    # not a perfect check. But we assume we don't share a param with shape (1,).
+                    if p.data.shape == p._orig_size and hasattr(p, "_is_shared") and p._is_shared:
+                        continue
+                    # If self.move_params_to_cpu and force_full_precision, we need to cast
+                    # the FP32 CPU param to CUDA for the all-gather.
+                    p_data = p.data.to(p._full_param_padded.device, non_blocking=True)
+
+                    p_size = p._full_param_padded.size()
+                    assert p_size.numel() % self.world_size == 0
+                    if self.mixed_precision and force_full_precision:
+                        # Allocate fresh tensor in full precision since we are in
+                        # mixed precision and full precision rebuild is asked.
+                        output_tensor = p_data.new_zeros(p_size)
+                    else:
+                        if p._full_param_padded.storage().size() != p_size.numel():
+                            # Allocate based on full size from all shards.
+                            alloc_storage_(p._full_param_padded, size=p_size)
+                        output_tensor = p._full_param_padded
+
+                    # Fill output_tensor with (p.data for each shard in self.world_size)
+                    if hasattr(dist, "_all_gather_base") and enable_nccl_base_collectives:
+                        # New version of PyTorch has all_gather_base, which is faster than chunk and then all_gather.
+                        dist._all_gather_base(output_tensor, p._zero2_fp16_shard, group=self.zero2_process_group)
+                    else:
+                        chunks = list(output_tensor.chunk(self.world_size))
+                        dist.all_gather(chunks, p._zero2_fp16_shard, group=self.zero2_process_group)
+
+                    # Set p.data = output_tensor (with padding trimmed)
+                    update_p_data(output_tensor)
+
+                    if (self.mixed_precision or self.move_params_to_cpu) and not force_full_precision:
+                        self._free_zero2_param_shard([p])
+
+                    if self.move_params_to_cpu and (self.params[0].dtype == self.compute_dtype):
+                        self._free_zero2_param_shard([p])
+
+        if wait_for_all_gather:
+            torch.cuda.current_stream().wait_stream(self._streams["all_gather"])
+        return output_tensors
+
+
     @torch.no_grad()
     def _use_full_params(self) -> None:
         """
@@ -2074,6 +2207,38 @@ def _free_full_params(self, params: Optional[List[Parameter]] = None) -> None:
            free_storage_(p._full_param_padded)
        torch.cuda.current_stream().synchronize()

+
+    def _zero2_shard_to_smaller_group(self, params: Optional[List[Parameter]] = None):
+        if params is None:
+            params = self.params
+        self.has_full_params = False
+        current_stream = torch.cuda.current_stream()
+        for p in params:
+            if not p._is_sharded:  # e.g., world_size == 1
+                if self.mixed_precision or self.move_params_to_cpu:
+                    self._free_fp16_param_shard([p])
+                continue
+            # Cases for when zero2 world size > 1 but less than zero3 size
+            zero2_world_size = dist.get_world_size(self.zero2_process_group)
+            zero2_rank = dist.get_rank(self.zero2_process_group)
+            chunks = p._full_param_padded.chunk(zero2_world_size)
+
+            p._zero2_fp16_shard = torch.empty_like(chunks[zero2_rank])
+            p._zero2_fp16_shard.copy_(chunks[zero2_rank])
+
+            # Don't let PyTorch reuse this memory until all work in the current
+            # stream is complete.
+            p._full_param_padded.record_stream(current_stream)
+            # There may be external references to the Tensor Storage that we
+            # can't modify, such as references that are created by
+            # ctx.save_for_backward in the forward pass. Thus when we
+            # unshard parameters, we should reuse the original Tensor
+            # Storage object and unshard it in-place. For now, just resize
+            # the Storage to 0 to save memory.
+            free_storage_(p._full_param_padded)
+        torch.cuda.current_stream().synchronize()
+
+
     def local_metadata_dict(self) -> Dict[str, Any]:
         """
         Get the information needed to reconstruct the model from shards offline.
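Conversely, `_zero2_shard_to_smaller_group` keeps only this rank's zero2-group chunk of the gathered parameter and releases the full buffer. A minimal sketch of the chunk-and-keep step under the same assumptions (placeholder names, not the fairscale code; the commit additionally uses record_stream and free_storage_ to release the full buffer safely across CUDA streams):

import torch
import torch.distributed as dist

def keep_local_zero2_chunk(full_param_padded: torch.Tensor, zero2_pg) -> torch.Tensor:
    # Split the padded full parameter evenly across the zero2 group and copy
    # out this rank's chunk, so the large buffer can be freed afterwards.
    zero2_world_size = dist.get_world_size(zero2_pg)
    zero2_rank = dist.get_rank(zero2_pg)
    chunks = full_param_padded.chunk(zero2_world_size)
    local_shard = torch.empty_like(chunks[zero2_rank])
    local_shard.copy_(chunks[zero2_rank])
    return local_shard
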
@@ -2238,6 +2403,19 @@ def _free_fp16_param_shard(self, params: Optional[List[Parameter]] = None) -> None:
                p._fp16_shard.record_stream(current_stream)
                free_storage_(p._fp16_shard)

+    @torch.no_grad()
+    def _free_zero2_param_shard(self, params: Optional[List[Parameter]] = None) -> None:
+        """Free storage for zero2 FP16 shards for a list of params."""
+        if params is None:
+            params = self.params
+        current_stream = torch.cuda.current_stream()
+        for p in params:
+            if p._zero2_fp16_shard is not None:
+                # _zero2_fp16_shard may still be in use by another stream, so we
+                # can't free it until the work in the current stream completes.
+                p._zero2_fp16_shard.record_stream(current_stream)
+                free_storage_(p._zero2_fp16_shard)
+
     def assert_state(self, state: Union[TrainingState, List[TrainingState]]) -> None:
         """Assert we are in the given state."""
         # Since assert can be turned off and this error checking
