[FSDPv1] Only perform cat() during last microbatch backward() within FlattenParamsWrapper #1180

Draft pull request: wants to merge 27 commits into base: ngoyal_changes_for_pp_fp8

Commits (27):

- d1102ce: use torch.no_grad() to avoid calling cat() during FSDP backward excep… (chrisxcai, Apr 29, 2024)
- 9a22628: remove logging (chrisxcai, Apr 29, 2024)
- f787532: logging (chrisxcai, Apr 30, 2024)
- 3429f33: logging (chrisxcai, May 1, 2024)
- 4b5abe2: use new field to accumulate per-parameter grads in fp32 and copy into… (chrisxcai, May 2, 2024)
- c97bfd9: clean up accumulated fp32 grads between data batches (chrisxcai, May 2, 2024)
- d2a88b7: logging (chrisxcai, May 2, 2024)
- 901fb86: logging (chrisxcai, May 6, 2024)
- ad40f24: return grad in post_backward_hook() (chrisxcai, May 8, 2024)
- 14499fe: correct param_index (chrisxcai, May 9, 2024)
- ad7aa1f: logging (chrisxcai, May 9, 2024)
- b835770: add torch.testing.assert_allclose() to compare baseline and new grads (chrisxcai, May 9, 2024)
- d689f38: logging (chrisxcai, May 9, 2024)
- e8df583: logging (chrisxcai, May 13, 2024)
- 5926a79: honor optimize_backward_concat flag (chrisxcai, May 15, 2024)
- 5d08aa3: documentation (chrisxcai, May 15, 2024)
- c91cb72: update documentation (chrisxcai, May 15, 2024)
- fd3f3fc: update documentation (chrisxcai, May 15, 2024)
- 7678503: use grad instead of grad.data (chrisxcai, May 15, 2024)
- c55a0d1: clean up (chrisxcai, May 15, 2024)
- 688b902: Added reshard hook for frozen params in backward (Jan 12, 2024)
- a3ff5c4: Avoid calling _free_fp16_param_shard() too early with PR 1159 (jiecaoyu, Feb 21, 2024)
- 9d0e41e: Added requires_grad check for params_with_grad method (#1171) (whbldhwj, Mar 25, 2024)
- e43a22f: Changed to only run reshard hook if all gradients computed (#1166) (awgu, Apr 1, 2024)
- f039a3a: Add cast input argument (#1175) (whbldhwj, Apr 5, 2024)
- 5299982: honor optimize_backward_concat flag (chrisxcai, May 15, 2024)
- b5e138f: use grad instead of grad.data (chrisxcai, May 15, 2024)

211 changes: 189 additions & 22 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -9,6 +9,7 @@
from dataclasses import dataclass
from enum import Enum, auto
import functools
import itertools
import logging
from math import inf
import os
@@ -27,6 +28,7 @@
Mapping,
NamedTuple,
Optional,
Sequence,
Set,
Tuple,
Union,
@@ -41,13 +43,13 @@
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch.utils.hooks import RemovableHandle

from fairscale.nn.misc import FlattenParamsWrapper
from fairscale.nn.wrap import auto_wrap, config_auto_wrap_policy, enable_wrap
from fairscale.utils.containers import apply_to_tensors
from fairscale.utils.parallel import (
ProcessGroupName,
chunk_and_pad,
enable_pytorch_sync_bn,
get_process_group_cached,
validate_process_group,
@@ -338,6 +340,11 @@ class FullyShardedDataParallel(nn.Module):
rank 0 and return empty dict non-rank 0, which allow FullyShardedDataParallel to
skip the GPU -> CPU copy on non-rank 0 altogether and prevent OOM.
Default: False
optimize_backward_concat (bool):
    If True, only let the backward pass propagate to self.params (which
    will invoke the _post_backward_hook() and the concat() op) when
    self._require_backward_grad_sync is True (e.g. on the last microbatch).
    NOTE: this will likely incur more GPU memory usage.
Review comment (awgu): Could you explain why there will be more GPU memory usage?

Reply (chrisxcai, author): hi @awgu, current test results show that the GPU memory overhead can be non-trivial (roughly 20% of 80 GB); we will follow up on reducing the memory usage.
(Two screenshots of GPU memory usage captured May 15, 2024 were attached.)
"""

def __init__(
@@ -368,7 +375,8 @@ def __init__(
gradient_predivide_factor: Optional[float] = None,
limit_all_gather_events: bool = False,
limit_reduce_scatter_events: bool = False,
should_validate_process_group: bool = True,
cast_input: bool = True,
optimize_backward_concat: bool = False,
):
try:
import torch._C
@@ -419,6 +427,7 @@ def __init__(
self.reshard_after_forward = self._orig_reshard_after_forward = reshard_after_forward
self.disable_reshard_on_root = disable_reshard_on_root
self.mixed_precision = mixed_precision
self.cast_input = cast_input
self.fp32_reduce_scatter = fp32_reduce_scatter
self.flatten_parameters = flatten_parameters
self.move_params_to_cpu = move_params_to_cpu or cpu_offload
@@ -452,7 +461,7 @@ def __init__(
raise ValueError(f"offload type: '{offload_config.offload_type}' requires flatten_parameters=True")

# skip validation if the process group was created above
if process_group and should_validate_process_group:
if process_group:
validate_process_group(self.compute_device, self.process_group)

# enable pytorch sync_bn just in case model contains sync_bn layers.
@@ -493,8 +502,12 @@ def __init__(
param_name_groups = [param_names]
del param_names

self.optimize_backward_concat = optimize_backward_concat
if self.optimize_backward_concat:
assert self.fp32_reduce_scatter, f"{optimize_backward_concat=} requires self.fp32_reduce_scatter=True"

self._fsdp_wrapped_module: nn.Module = FlattenParamsWrapper(
module, param_list=to_be_flatten_params, ssd_offload=self.ssd_offload, ssd_directory=self.ssd_directory
module, param_list=to_be_flatten_params, ssd_offload=self.ssd_offload, ssd_directory=self.ssd_directory, optimize_backward_concat=self.optimize_backward_concat,
)
del module # free original module in case it helps garbage collection

@@ -688,7 +701,7 @@ def _cast_buffers(
@property
def params_with_grad(self) -> List[Parameter]:
"""[p for p in self.parameters() if p.grad is not None]"""
return [p for p in self.parameters() if (p.grad is not None or p.main_grad is not None)]
return [p for p in self.parameters() if (p.requires_grad and (p.grad is not None or p.main_grad is not None))]

@torch.no_grad()
def clip_grad_norm_(
@@ -851,6 +864,7 @@ def extra_repr(self) -> str:
f"bucket_cap_mb={self.bucket_cap_mb}, "
f"clear_autocast_cache={self.clear_autocast_cache}"
f"force_input_to_fp32={self.force_input_to_fp32}"
f"optimize_backward_concat={self.optimize_backward_concat}"
)
return repr

@@ -1099,12 +1113,20 @@ def no_sync(self) -> Generator:
if isinstance(m, FullyShardedDataParallel):
old_flags.append((m, m._require_backward_grad_sync))
m._require_backward_grad_sync = False
if self.optimize_backward_concat:
# Set the flag on the wrapped FlattenParamsWrapper module as well,
# so that FlattenParamsWrapper could accumulate grads at corresponding
# leaf nodes without triggering concat operations when gradient
# synchronization is not needed.
m._fsdp_wrapped_module._require_backward_grad_sync = False
try:
yield
finally:
for m, old_flag in old_flags:
assert m._require_backward_grad_sync is False
m._require_backward_grad_sync = old_flag
if self.optimize_backward_concat:
m._fsdp_wrapped_module._require_backward_grad_sync = old_flag

@contextlib.contextmanager
def summon_full_params(self, recurse: bool = True, volatile: bool = False) -> Generator:
@@ -1430,7 +1452,7 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
# For root and mixed precision, we convert the input to FP16 (no_grad is needed for
# the conversion).
is_bf16 = self.compute_dtype == torch.bfloat16
if self._is_root and self.mixed_precision:
if self._is_root and self.mixed_precision and self.cast_input:
args, kwargs = cast_floats_to_right_precision(True, True, is_bf16, *args, **kwargs)

if self not in self._fsdp_forward_ordering:
Expand Down Expand Up @@ -1458,6 +1480,7 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
# Register backward hooks to reshard params and reduce-scatter grads.
# These need to be re-registered every forward pass.
self._register_post_backward_hooks()
self._register_post_backward_reshard_hooks(args, kwargs)

outputs = self.module(*args, **kwargs)

@@ -1656,6 +1679,34 @@ def _register_post_backward_hooks(self) -> None:
p._shard_bwd_hooks.append((grad_acc, handle))
# p._shard_bwd_hook = (grad_acc, handle)

def _register_post_backward_reshard_hooks(
self, args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> None:
if not torch.is_grad_enabled():
return
from torch.utils._pytree import tree_flatten
# Construct `inp_tensors` lazily to avoid CPU overhead in typical case
# where each parameter requires gradient
inp_tensors: Optional[List[torch.Tensor]] = None
for param in self.params:
# Only register for parameters that do not require gradient
if param.requires_grad:
continue
if inp_tensors is None:
args_list, _ = tree_flatten(args)
kwargs_list, _ = tree_flatten(kwargs)
inp_tensors = [
obj
for obj in itertools.chain(args_list, kwargs_list)
if torch.is_tensor(obj) and obj.requires_grad
]
hook_handle = register_multi_grad_hook(
inp_tensors, functools.partial(self._post_backward_reshard_hook, param)
)
if not hasattr(param, "_shard_bwd_hooks"):
param._shard_bwd_hooks = []
param._shard_bwd_hooks.append((hook_handle,))

@torch.no_grad()
def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
"""
@@ -1698,15 +1749,11 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
if param.grad.requires_grad:
raise RuntimeError("FSDP only works with gradients that don't require gradients")

if self._require_backward_grad_sync or self.reshard_after_forward:
# Free full params. As a special case, we don't free the full params
# when in a ``no_sync`` context (as inversely indicated by
# ``self._require_backward_grad_sync``), since the params will not
# get updated before the next forward. This saves networking
# bandwidth but uses more GPU memory.
if self._should_free_in_backward():
# Free full params.
self._free_full_params([param])

if self.mixed_precision:
if self.mixed_precision and (self._require_backward_grad_sync or self.reshard_after_forward):
# This is a no-op if reshard_after_forward is True, since we already
# free the param shard when rebuilding the full params in the
# pre_backward_hook.
@@ -1716,10 +1763,17 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
self._use_fp32_param_shard([param])

if self.fp32_reduce_scatter:
if getattr(param, "unsharded_main_grad", None) is None:
param.unsharded_main_grad = param.grad.to(torch.float32)
if self.optimize_backward_concat:
# Flatten and concat the accumulated fp32 grads
# and assign them to param.unsharded_main_grad
param.unsharded_main_grad = torch.cat([grad.flatten() for grad in self._fsdp_wrapped_module.fp32_grads])
# Clean up accumulated grads between data batches
self._fsdp_wrapped_module.fp32_grads = []
else:
param.unsharded_main_grad.add_(param.grad.data)
if getattr(param, "unsharded_main_grad", None) is None:
param.unsharded_main_grad = param.grad.to(torch.float32)
else:
param.unsharded_main_grad.add_(param.grad.data)

param.grad = None

@@ -1830,6 +1884,22 @@ def _post_reduction_hook(self, param: Parameter, reduced_grad: torch.Tensor) ->
# Don't let this memory get reused until after the transfer.
reduced_grad.data.record_stream(torch.cuda.current_stream())

@torch.no_grad()
def _post_backward_reshard_hook(self, param: Parameter, *unused: Any) -> None:
if self._should_free_in_backward():
self._free_full_params([param])
if self.mixed_precision and (self._require_backward_grad_sync or self.reshard_after_forward):
self._free_fp16_param_shard([param])
self._use_fp32_param_shard([param])

def _should_free_in_backward(self):
# As a special case, we don't free the full params
# when in a ``no_sync`` context (as inversely indicated by
# ``self._require_backward_grad_sync``), since the params will not
# get updated before the next forward. This saves networking
# bandwidth but uses more GPU memory.
return self._require_backward_grad_sync or self.reshard_after_forward

def _queue_wait_for_post_backward(self) -> None:
"""Try to queue a `wait_for_post_backward` callback.

@@ -1852,7 +1922,16 @@ def _wait_for_post_backward(self) -> None:
# all the params, the post_backward hook will not fire and the
# state will remain in `TrainingState.BACKWARD_PRE`.
if any([p.requires_grad for p in self.params]):
self.assert_state(TrainingState.BACKWARD_POST)
if self.optimize_backward_concat:
# If self.optimize_backward_concat==True, FSDP backward should
# only be triggered (which will invoke concat())
# when self._fsdp_wrapped_module._require_backward_grad_sync = True
if self._fsdp_wrapped_module._require_backward_grad_sync:
self.assert_state(TrainingState.BACKWARD_POST)
else:
self.assert_state(TrainingState.BACKWARD_PRE)
else:
self.assert_state(TrainingState.BACKWARD_POST)
else:
self.assert_state(TrainingState.BACKWARD_PRE)

@@ -1879,16 +1958,24 @@ def _wait_for_post_backward(self) -> None:
def _finalize_parameters(fsdp_module: FullyShardedDataParallel) -> None:
"""Helper used below on all fsdp modules."""
for p in fsdp_module.params:
if not p.requires_grad:
continue
if hasattr(p, "_shard_bwd_hook"):
p_assert(len(p._shard_bwd_hook) == 2, f"WFPB: incorrect hook num: {len(p._shard_bwd_hook)}")
# p._shard_bwd_hook[1].remove()
# delattr(p, "_shard_bwd_hook")
if hasattr(p, "_shard_bwd_hooks") and self._require_backward_grad_sync:
for _, handle in p._shard_bwd_hooks:
handle.remove()
for hook_state in p._shard_bwd_hooks:
if len(hook_state) == 1:
hook_state[0].remove()
elif len(hook_state) == 2:
hook_state[1].remove()
p._shard_bwd_hooks.clear()
if not p.requires_grad:
# For the 1st layer, if the forward inputs did not require
# gradient, then we cannot run a reshard hook for it, and
# we instead free here.
if p._is_sharded and p._full_param_padded.untyped_storage().size() > 0:
fsdp_module._post_backward_reshard_hook(p)
continue

# Leave the gradient accumulation state as-is if not synchronizing this pass. This ensures p.grad
# remains the unsharded gradient accumulated from prior no-sync passes, and p._saved_grad_shard
@@ -1929,7 +2016,16 @@ def _finalize_parameters(fsdp_module: FullyShardedDataParallel) -> None:
# all the params, the post_backward hook will not fire and the
# state will remain in `TrainingState.BACKWARD_PRE`.
if any([p.requires_grad for p in m.params]):
m.assert_state(TrainingState.BACKWARD_POST)
if self.optimize_backward_concat:
# If self.optimize_backward_concat==True, FSDP backward should
# only be triggered (which will invoke concat())
# when self._fsdp_wrapped_module._require_backward_grad_sync = True
if self._fsdp_wrapped_module._require_backward_grad_sync:
m.assert_state(TrainingState.BACKWARD_POST)
else:
m.assert_state(TrainingState.BACKWARD_PRE)
else:
m.assert_state(TrainingState.BACKWARD_POST)
else:
m.assert_state(TrainingState.BACKWARD_PRE)
else:
@@ -2772,3 +2868,74 @@ def auto_wrap_bn(
enable_wrap(config_auto_wrap_policy, wrapper_cls=FullyShardedDataParallel) if wrap_it else contextlib.suppress()
):
return auto_wrap(module)


class Handle(RemovableHandle):
handles: Tuple[RemovableHandle, ...]

def __init__(self, handles: Tuple[RemovableHandle, ...]):
self.handles = handles

def remove(self):
for handle in self.handles:
handle.remove()

def __getstate__(self):
return self.handles

def __setstate__(self, state):
self.handles = state


def register_multi_grad_hook(
tensors: Sequence[torch.Tensor],
fn: Callable[[Sequence[Optional[torch.Tensor]]], None]
):
count: Dict[int, int] = dict()
nb_calls = None
buffer: Dict[int, List[Optional[torch.Tensor]]] = dict()

grad_fns = list(map(_get_grad_fn_or_grad_acc, tensors))
len_tensors = len(tensors)

def get_inner_hook(idx):
def inner_hook(grad: torch.Tensor):
nonlocal count, nb_calls, buffer, fn
id = torch._C._current_graph_task_id()
assert (
id != -1
), "expected this hook to be called inside a backward call"
count[id] = count.get(id, 0)
buffer[id] = buffer.get(id, [None] * len_tensors)

if count[id] == 0:
# On the first call, compute the actual nb_calls and buffer
# nb_calls = sum(torch._C._will_engine_execute_node(g) for g in grad_fns) # type: ignore[attr-defined]

# NOTE: To avoid resharding too early when microbatches share
# some same module inputs, let us require all gradients to be
# computed in this backward for the hook to run.
nb_calls = len(grad_fns)

buffer[id][idx] = grad
count[id] += 1

if count[id] == nb_calls:
fn = cast(Callable[[Sequence[Optional[torch.Tensor]]], None], fn)
fn(buffer[id])
del count[id]
del buffer[id]

return inner_hook

handles: Tuple[RemovableHandle, ...] = tuple(
t.register_hook(get_inner_hook(i)) for i, t in enumerate(tensors)
)
return Handle(handles)


def _get_grad_fn_or_grad_acc(t):
if t.requires_grad and t.grad_fn is None:
return t.view_as(t).grad_fn.next_functions[0][0]
else:
return t.grad_fn
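For reference, the register_multi_grad_hook helper above is adapted from torch.autograd.graph.register_multi_grad_hook in recent PyTorch releases; a minimal standalone sketch (editorial illustration, using the upstream API) of the fire-once-all-grads-are-computed behavior:

```python
import torch
from torch.autograd.graph import register_multi_grad_hook

a = torch.randn(4, requires_grad=True)
b = torch.randn(4, requires_grad=True)

def on_all_grads(grads):
    # Invoked once per backward pass, only after gradients for BOTH `a` and `b`
    # have been computed within that pass.
    print([None if g is None else tuple(g.shape) for g in grads])

handle = register_multi_grad_hook((a, b), on_all_grads)
(a * b).sum().backward()   # prints [(4,), (4,)]
handle.remove()
```

The copy in this diff differs from upstream in that it always waits for gradients of all hooked tensors (see the NOTE in the inner hook), rather than only the nodes the autograd engine will actually execute.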