diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py index cdd6e6e8c..13f1e5577 100644 --- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py +++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py @@ -9,6 +9,7 @@ from dataclasses import dataclass from enum import Enum, auto import functools +import itertools import logging from math import inf import os @@ -27,6 +28,7 @@ Mapping, NamedTuple, Optional, + Sequence, Set, Tuple, Union, @@ -41,13 +43,13 @@ import torch.nn as nn import torch.nn.functional as F from torch.nn.parameter import Parameter +from torch.utils.hooks import RemovableHandle from fairscale.nn.misc import FlattenParamsWrapper from fairscale.nn.wrap import auto_wrap, config_auto_wrap_policy, enable_wrap from fairscale.utils.containers import apply_to_tensors from fairscale.utils.parallel import ( ProcessGroupName, - chunk_and_pad, enable_pytorch_sync_bn, get_process_group_cached, validate_process_group, @@ -338,6 +340,11 @@ class FullyShardedDataParallel(nn.Module): rank 0 and return empty dict non-rank 0, which allow FullyShardedDataParallel to skip the GPU -> CPU copy on non-rank 0 altogether and prevent OOM. Default: False + optimize_backward_concat (bool): + If True, only let backward pass propagate to self.params, which will + invoke the _post_backward_hook() and concat() op, when self._require_backward_grad_sync + is True (e.g. last microbatch) + NOTE: this likely will incur more GPU memory usage """ def __init__( @@ -368,7 +375,8 @@ def __init__( gradient_predivide_factor: Optional[float] = None, limit_all_gather_events: bool = False, limit_reduce_scatter_events: bool = False, - should_validate_process_group: bool = True, + cast_input: bool = True, + optimize_backward_concat: bool = False, ): try: import torch._C @@ -419,6 +427,7 @@ def __init__( self.reshard_after_forward = self._orig_reshard_after_forward = reshard_after_forward self.disable_reshard_on_root = disable_reshard_on_root self.mixed_precision = mixed_precision + self.cast_input = cast_input self.fp32_reduce_scatter = fp32_reduce_scatter self.flatten_parameters = flatten_parameters self.move_params_to_cpu = move_params_to_cpu or cpu_offload @@ -452,7 +461,7 @@ def __init__( raise ValueError(f"offload type: '{offload_config.offload_type}' requires flatten_parameters=True") # skip validation if the process group was created above - if process_group and should_validate_process_group: + if process_group: validate_process_group(self.compute_device, self.process_group) # enable pytorch sync_bn just in case model contains sync_bn layers. 
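For context, here is a minimal usage sketch of the two constructor flags introduced above (`cast_input` and `optimize_backward_concat`). The model, sizes, and flag values are illustrative only and are not part of this patch; note that later in this diff `optimize_backward_concat=True` is asserted to require `fp32_reduce_scatter=True`.

```python
# Illustrative sketch only (not from this patch): wrapping a toy model with the
# new flags. Assumes torch.distributed/NCCL is already initialized and the
# current CUDA device has been set.
import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16)).cuda()
fsdp_model = FSDP(
    model,
    mixed_precision=True,
    flatten_parameters=True,        # optimize_backward_concat targets the flattened path
    fp32_reduce_scatter=True,       # required when optimize_backward_concat=True
    cast_input=False,               # new flag: skip FSDP's fp16/bf16 cast of forward inputs
    optimize_backward_concat=True,  # new flag: defer the flat-grad concat to the synced backward
)
```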
@@ -493,8 +502,12 @@ def __init__( param_name_groups = [param_names] del param_names + self.optimize_backward_concat = optimize_backward_concat + if self.optimize_backward_concat: + assert self.fp32_reduce_scatter, f"{optimize_backward_concat=} requires self.fp32_reduce_scatter=True" + self._fsdp_wrapped_module: nn.Module = FlattenParamsWrapper( - module, param_list=to_be_flatten_params, ssd_offload=self.ssd_offload, ssd_directory=self.ssd_directory + module, param_list=to_be_flatten_params, ssd_offload=self.ssd_offload, ssd_directory=self.ssd_directory, optimize_backward_concat=self.optimize_backward_concat, ) del module # free original module in case it helps garbage collection @@ -688,7 +701,7 @@ def _cast_buffers( @property def params_with_grad(self) -> List[Parameter]: """[p for p in self.parameters() if p.grad is not None]""" - return [p for p in self.parameters() if (p.grad is not None or p.main_grad is not None)] + return [p for p in self.parameters() if (p.requires_grad and (p.grad is not None or p.main_grad is not None))] @torch.no_grad() def clip_grad_norm_( @@ -851,6 +864,7 @@ def extra_repr(self) -> str: f"bucket_cap_mb={self.bucket_cap_mb}, " f"clear_autocast_cache={self.clear_autocast_cache}" f"force_input_to_fp32={self.force_input_to_fp32}" + f"optimize_backward_concat={self.optimize_backward_concat}" ) return repr @@ -1099,12 +1113,20 @@ def no_sync(self) -> Generator: if isinstance(m, FullyShardedDataParallel): old_flags.append((m, m._require_backward_grad_sync)) m._require_backward_grad_sync = False + if self.optimize_backward_concat: + # Set the flag on the wrapped FlattenParamsWrapper module as well, + # so that FlattenParamsWrapper could accumulate grads at corresponding + # leaf nodes without triggering concat operations when gradient + # synchronization is not needed. + m._fsdp_wrapped_module._require_backward_grad_sync = False try: yield finally: for m, old_flag in old_flags: assert m._require_backward_grad_sync is False m._require_backward_grad_sync = old_flag + if self.optimize_backward_concat: + m._fsdp_wrapped_module._require_backward_grad_sync = old_flag @contextlib.contextmanager def summon_full_params(self, recurse: bool = True, volatile: bool = False) -> Generator: @@ -1430,7 +1452,7 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor: # For root and mixed precision, we convert the input to FP16 (no_grad is needed for # the conversion). is_bf16 = self.compute_dtype == torch.bfloat16 - if self._is_root and self.mixed_precision: + if self._is_root and self.mixed_precision and self.cast_input: args, kwargs = cast_floats_to_right_precision(True, True, is_bf16, *args, **kwargs) if self not in self._fsdp_forward_ordering: @@ -1458,6 +1480,7 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor: # Register backward hooks to reshard params and reduce-scatter grads. # These need to be re-registered every forward pass. 
self._register_post_backward_hooks() + self._register_post_backward_reshard_hooks(args, kwargs) outputs = self.module(*args, **kwargs) @@ -1656,6 +1679,34 @@ def _register_post_backward_hooks(self) -> None: p._shard_bwd_hooks.append((grad_acc, handle)) # p._shard_bwd_hook = (grad_acc, handle) + def _register_post_backward_reshard_hooks( + self, args: Tuple[Any, ...], kwargs: Dict[str, Any] + ) -> None: + if not torch.is_grad_enabled(): + return + from torch.utils._pytree import tree_flatten + # Construct `inp_tensors` lazily to avoid CPU overhead in typical case + # where each parameter requires gradient + inp_tensors: Optional[List[torch.Tensor]] = None + for param in self.params: + # Only register for parameters that do not require gradient + if param.requires_grad: + continue + if inp_tensors is None: + args_list, _ = tree_flatten(args) + kwargs_list, _ = tree_flatten(kwargs) + inp_tensors = [ + obj + for obj in itertools.chain(args_list, kwargs_list) + if torch.is_tensor(obj) and obj.requires_grad + ] + hook_handle = register_multi_grad_hook( + inp_tensors, functools.partial(self._post_backward_reshard_hook, param) + ) + if not hasattr(param, "_shard_bwd_hooks"): + param._shard_bwd_hooks = [] + param._shard_bwd_hooks.append((hook_handle,)) + @torch.no_grad() def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: """ @@ -1698,15 +1749,11 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: if param.grad.requires_grad: raise RuntimeError("FSDP only works with gradients that don't require gradients") - if self._require_backward_grad_sync or self.reshard_after_forward: - # Free full params. As a special case, we don't free the full params - # when in a ``no_sync`` context (as inversely indicated by - # ``self._require_backward_grad_sync``), since the params will not - # get updated before the next forward. This saves networking - # bandwidth but uses more GPU memory. + if self._should_free_in_backward(): + # Free full params. self._free_full_params([param]) - if self.mixed_precision: + if self.mixed_precision and (self._require_backward_grad_sync or self.reshard_after_forward): # This is a no-op if reshard_after_forward is True, since we already # free the param shard when rebuilding the full params in the # pre_backward_hook. @@ -1716,10 +1763,17 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: self._use_fp32_param_shard([param]) if self.fp32_reduce_scatter: - if getattr(param, "unsharded_main_grad", None) is None: - param.unsharded_main_grad = param.grad.to(torch.float32) + if self.optimize_backward_concat: + # Flatten and concat the accumulated fp32 grads + # and assign them to param.unsharded_main_grad + param.unsharded_main_grad = torch.cat([grad.flatten() for grad in self._fsdp_wrapped_module.fp32_grads]) + # Clean up accumulated grads between data batches + self._fsdp_wrapped_module.fp32_grads = [] else: - param.unsharded_main_grad.add_(param.grad.data) + if getattr(param, "unsharded_main_grad", None) is None: + param.unsharded_main_grad = param.grad.to(torch.float32) + else: + param.unsharded_main_grad.add_(param.grad.data) param.grad = None @@ -1830,6 +1884,22 @@ def _post_reduction_hook(self, param: Parameter, reduced_grad: torch.Tensor) -> # Don't let this memory get reused until after the transfer. 
reduced_grad.data.record_stream(torch.cuda.current_stream()) + @torch.no_grad() + def _post_backward_reshard_hook(self, param: Parameter, *unused: Any) -> None: + if self._should_free_in_backward(): + self._free_full_params([param]) + if self.mixed_precision and (self._require_backward_grad_sync or self.reshard_after_forward): + self._free_fp16_param_shard([param]) + self._use_fp32_param_shard([param]) + + def _should_free_in_backward(self): + # As a special case, we don't free the full params + # when in a ``no_sync`` context (as inversely indicated by + # ``self._require_backward_grad_sync``), since the params will not + # get updated before the next forward. This saves networking + # bandwidth but uses more GPU memory. + return self._require_backward_grad_sync or self.reshard_after_forward + def _queue_wait_for_post_backward(self) -> None: """Try to queue a `wait_for_post_backward` callback. @@ -1852,7 +1922,16 @@ def _wait_for_post_backward(self) -> None: # all the params, the post_backward hook will not fire and the # state will remain in `TrainingState.BACKWARD_PRE`. if any([p.requires_grad for p in self.params]): - self.assert_state(TrainingState.BACKWARD_POST) + if self.optimize_backward_concat: + # If self.optimize_backward_concat==True, FSDP backward should + # only be triggered (which will invoke concat()) + # when self._fsdp_wrapped_module._require_backward_grad_sync = True + if self._fsdp_wrapped_module._require_backward_grad_sync: + self.assert_state(TrainingState.BACKWARD_POST) + else: + self.assert_state(TrainingState.BACKWARD_PRE) + else: + self.assert_state(TrainingState.BACKWARD_POST) else: self.assert_state(TrainingState.BACKWARD_PRE) @@ -1879,16 +1958,24 @@ def _wait_for_post_backward(self) -> None: def _finalize_parameters(fsdp_module: FullyShardedDataParallel) -> None: """Helper used below on all fsdp modules.""" for p in fsdp_module.params: - if not p.requires_grad: - continue if hasattr(p, "_shard_bwd_hook"): p_assert(len(p._shard_bwd_hook) == 2, f"WFPB: incorrect hook num: {len(p._shard_bwd_hook)}") # p._shard_bwd_hook[1].remove() # delattr(p, "_shard_bwd_hook") if hasattr(p, "_shard_bwd_hooks") and self._require_backward_grad_sync: - for _, handle in p._shard_bwd_hooks: - handle.remove() + for hook_state in p._shard_bwd_hooks: + if len(hook_state) == 1: + hook_state[0].remove() + elif len(hook_state) == 2: + hook_state[1].remove() p._shard_bwd_hooks.clear() + if not p.requires_grad: + # For the 1st layer, if the forward inputs did not require + # gradient, then we cannot run a reshard hook for it, and + # we instead free here. + if p._is_sharded and p._full_param_padded.untyped_storage().size() > 0: + fsdp_module._post_backward_reshard_hook(p) + continue # Leave the gradient accumulation state as-is if not synchronizing this pass. This ensures p.grad # remains the unsharded gradient accumulated from prior no-sync passes, and p._saved_grad_shard @@ -1929,7 +2016,16 @@ def _finalize_parameters(fsdp_module: FullyShardedDataParallel) -> None: # all the params, the post_backward hook will not fire and the # state will remain in `TrainingState.BACKWARD_PRE`. 
if any([p.requires_grad for p in m.params]): - m.assert_state(TrainingState.BACKWARD_POST) + if self.optimize_backward_concat: + # If self.optimize_backward_concat==True, FSDP backward should + # only be triggered (which will invoke concat()) + # when self._fsdp_wrapped_module._require_backward_grad_sync = True + if self._fsdp_wrapped_module._require_backward_grad_sync: + m.assert_state(TrainingState.BACKWARD_POST) + else: + m.assert_state(TrainingState.BACKWARD_PRE) + else: + m.assert_state(TrainingState.BACKWARD_POST) else: m.assert_state(TrainingState.BACKWARD_PRE) else: @@ -2772,3 +2868,74 @@ def auto_wrap_bn( enable_wrap(config_auto_wrap_policy, wrapper_cls=FullyShardedDataParallel) if wrap_it else contextlib.suppress() ): return auto_wrap(module) + + +class Handle(RemovableHandle): + handles: Tuple[RemovableHandle, ...] + + def __init__(self, handles: Tuple[RemovableHandle, ...]): + self.handles = handles + + def remove(self): + for handle in self.handles: + handle.remove() + + def __getstate__(self): + return self.handles + + def __setstate__(self, state): + self.handles = state + + +def register_multi_grad_hook( + tensors: Sequence[torch.Tensor], + fn: Callable[[Sequence[Optional[torch.Tensor]]], None] +): + count: Dict[int, int] = dict() + nb_calls = None + buffer: Dict[int, List[Optional[torch.Tensor]]] = dict() + + grad_fns = list(map(_get_grad_fn_or_grad_acc, tensors)) + len_tensors = len(tensors) + + def get_inner_hook(idx): + def inner_hook(grad: torch.Tensor): + nonlocal count, nb_calls, buffer, fn + id = torch._C._current_graph_task_id() + assert ( + id != -1 + ), "expected this hook to be called inside a backward call" + count[id] = count.get(id, 0) + buffer[id] = buffer.get(id, [None] * len_tensors) + + if count[id] == 0: + # On the first call, compute the actual nb_calls and buffer + # nb_calls = sum(torch._C._will_engine_execute_node(g) for g in grad_fns) # type: ignore[attr-defined] + + # NOTE: To avoid resharding too early when microbatches share + # some same module inputs, let us require all gradients to be + # computed in this backward for the hook to run. + nb_calls = len(grad_fns) + + buffer[id][idx] = grad + count[id] += 1 + + if count[id] == nb_calls: + fn = cast(Callable[[Sequence[Optional[torch.Tensor]]], None], fn) + fn(buffer[id]) + del count[id] + del buffer[id] + + return inner_hook + + handles: Tuple[RemovableHandle, ...] = tuple( + t.register_hook(get_inner_hook(i)) for i, t in enumerate(tensors) + ) + return Handle(handles) + + +def _get_grad_fn_or_grad_acc(t): + if t.requires_grad and t.grad_fn is None: + return t.view_as(t).grad_fn.next_functions[0][0] + else: + return t.grad_fn diff --git a/fairscale/nn/misc/flatten_params_wrapper.py b/fairscale/nn/misc/flatten_params_wrapper.py index 38265dd2b..7ef9ea1a9 100644 --- a/fairscale/nn/misc/flatten_params_wrapper.py +++ b/fairscale/nn/misc/flatten_params_wrapper.py @@ -7,6 +7,7 @@ # Licensed under the MIT License. from contextlib import contextmanager +import functools from itertools import chain import tempfile import typing @@ -33,11 +34,11 @@ from fairscale.experimental.nn.ssd_offload import SsdFlatParameter from fairscale.utils.state_dict import replace_by_prefix_ +import functools if TYPE_CHECKING: from collections import OrderedDict # noqa: F401 - class FlatParameter(nn.Parameter): """A parameter that is initialized from a list of parameters and can be turned into a list of views as needed. 
@@ -90,7 +91,9 @@ def get_param_views(self, external_data: Optional[Tensor] = None) -> Iterator[Te
             raise ValueError(
                 f"Incorrect numel of supplied data: got {data.numel()} but expected {sum(self._param_numels)}"
             )
-        return (t.view(s) for (t, s) in zip(data.split(self._param_numels), self._param_shapes))
+
+        split_outputs = data.split(self._param_numels)
+        return (t.view(s) for (t, s) in zip(split_outputs, self._param_shapes))
 
     def metadata(self) -> Tuple[List[str], List[torch.Size], List[int]]:
         """Return tuple of (names, shapes, numels) metadata for this flat parameter."""
@@ -148,6 +151,11 @@ class FlattenParamsWrapper(nn.Module):
         flat_param_names (Optional[List[str]]):
             originally, give each flat_param a unique name. Note a "flat_param_"
             prefix will be added to those names.
+        optimize_backward_concat (bool):
+            If True, only let backward pass propagate to the corresponding FSDP.params, which will
+            invoke the FSDP._post_backward_hook() and concat() op, when _require_backward_grad_sync
+            is True (e.g. last microbatch)
+            NOTE: this likely will incur more GPU memory usage
     """
 
     def __init__(
@@ -157,10 +165,18 @@ def __init__(
         flat_param_names: Optional[List[str]] = None,
         ssd_offload: bool = False,
         ssd_directory: str = "",
+        optimize_backward_concat: bool = False,
     ):
         super().__init__()
         self._fpw_module = module
         self.is_flattened = False
+        self.optimize_backward_concat = optimize_backward_concat
+        # If optimize_backward_concat == True, used to propagate the
+        # corresponding FSDP module's _require_backward_grad_sync flag
+        self._require_backward_grad_sync = True
+        # If optimize_backward_concat == True, used to accumulate the
+        # fp32 gradients for the flattened parameters
+        self.fp32_grads = []
 
         # Handle param_list being None.
         if param_list is None:
@@ -364,18 +380,60 @@ def _unflatten_params(self, external_data: Optional[List[Optional[Tensor]]] = No
             delattr(self, n)
         self.flat_params = []
 
+    # The post backward hook used to accumulate fp32 gradients
+    def _grad_accumulation_hook(
+        self,
+        grad,
+        param_index,
+    ):
+        if self.fp32_grads[param_index] is None:
+            self.fp32_grads[param_index] = grad.to(torch.float32)
+        else:
+            self.fp32_grads[param_index].add_(grad)
+        return grad
+
     def _unflatten_params_as_views(self) -> None:
         """Unlike ``_unflatten_params``, this function unflatten into views and keep
         self.flat_param unchanged.
         """
         assert self.is_flattened
-        ps = self.get_param_views()
+        if self.optimize_backward_concat:
+            # If self._require_backward_grad_sync == True (e.g. last microbatch),
+            # we use the original flat_params as autograd leaf nodes and backward
+            # pass should propagate all the way back to FSDP module and thus invoke
+            # FSDP post_backward() hook and concat() op
+            # Otherwise we stop the backward propagation before FSDP module to avoid
+            # invoking concat() and store the accumulated fp32 grads
+            if self._require_backward_grad_sync:
+                ps = self.get_param_views()
+            else:
+                with torch.no_grad():
+                    ps = self.get_param_views()
+        else:
+            ps = self.get_param_views()
+
         param_views = []
         for (_, m, n), p in zip(self._param_infos, ps):
             setattr(p, '_fsdp_weight', True)
             setattr(m, n, p)  # This will set as plain attr
+            # The param_index of p is used to accumulate the corresponding
+            # gradients in self.fp32_grads
+            param_index = len(param_views)
+            if self.optimize_backward_concat:
+                # Register post backward hook to accumulate the gradients
+                # in self.fp32_grads
+                p.register_hook(
+                    functools.partial(
+                        self._grad_accumulation_hook,
+                        param_index=param_index
+                    )
+                )
             param_views.append(p)
 
+        if self.optimize_backward_concat and len(self.fp32_grads) == 0:
+            # Allocate self.fp32_grads at the beginning of each data batch's forward()
+            self.fp32_grads = [None] * len(param_views)
+
         # Save param views for easy access if anyone still wants to access
         # parameters of the module.
         setattr(self._fpw_module, "_unflattened_param_views", param_views)
diff --git a/tests/nn/data_parallel/test_fsdp_freezing_weights.py b/tests/nn/data_parallel/test_fsdp_freezing_weights.py
index c6ad364f7..7baadc5d9 100644
--- a/tests/nn/data_parallel/test_fsdp_freezing_weights.py
+++ b/tests/nn/data_parallel/test_fsdp_freezing_weights.py
@@ -12,6 +12,8 @@
 from enum import Enum
 from itertools import product
+from unittest import mock
+import copy
 import tempfile
 
 import pytest
@@ -275,3 +277,97 @@ def test_freezing_weights(temp_files, nested_trunk):
             nprocs=world_size,
         )
         temp_file_idx += 3
+
+
+@skip_if_single_gpu
+def test_reshard_frozen_weights():
+    world_size = 2
+    for flatten_parameters, reshard_after_forward, inp_requires_grad in product(
+        [False, True], [False, True], [False, True]
+    ):
+        print(
+            "Testing FSDP reshard frozen weights with "
+            f"flatten_parameters={flatten_parameters}, "
+            f"reshard_after_forward={reshard_after_forward}, "
+            f"inp_requires_grad={inp_requires_grad}"
+        )
+        mp.spawn(
+            _distributed_worker_reshard,
+            (world_size, flatten_parameters, reshard_after_forward, inp_requires_grad),
+            nprocs=world_size,
+        )
+
+
+def _distributed_worker_reshard(
+    rank: int,
+    world_size: int,
+    flatten_parameters: bool,
+    reshard_after_forward: bool,
+    inp_requires_grad: bool,
+):
+    import os
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "12355"
+    torch.cuda.set_device(rank)
+    torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size)
+
+    torch.manual_seed(0)
+
+    num_linears = 6
+    modules = []
+    for _ in range(num_linears):
+        modules += [nn.Linear(5, 5, device="cuda"), nn.ReLU()]
+    model = nn.Sequential(*modules)
+    # Freeze every other linear
+    for i in range(num_linears):
+        if i % 2 == 0:
+            for param in model[i * 2].parameters(recurse=False):
+                param.requires_grad = False
+    num_frozen_linears = num_linears // 2
+
+    ref_model = DistributedDataParallel(copy.deepcopy(model), device_ids=[rank])
+    ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2)
+
+    for i, module in enumerate(model):
+        if isinstance(module, nn.Linear):
+            model[i] = FSDP(
+                module,
+                flatten_parameters=flatten_parameters,
reshard_after_forward=reshard_after_forward, + ) + fsdp_model = FSDP( + model, + flatten_parameters=flatten_parameters, + reshard_after_forward=reshard_after_forward, + ) + fsdp_optim = torch.optim.AdamW(fsdp_model.parameters(), lr=1e-2) + + orig_post_backward_reshard_hook = FSDP._post_backward_reshard_hook + reshard_hook_count = 0 + + def post_backward_reshard_hook_with_count(*args, **kwargs): + nonlocal reshard_hook_count + reshard_hook_count += 1 + return orig_post_backward_reshard_hook(*args, **kwargs) + + with mock.patch( + "fairscale.nn.data_parallel.FullyShardedDataParallel._post_backward_reshard_hook", + post_backward_reshard_hook_with_count, + ): + inp = torch.randn((8, 5), device="cuda", requires_grad=inp_requires_grad) + for i in range(6): + losses = [] + for model, optim in ((fsdp_model, fsdp_optim), (ref_model, ref_optim)): + optim.zero_grad() + loss = model(inp).sum() + losses.append(loss) + loss.backward() + optim.step() + expected_reshard_hook_count = num_frozen_linears + if not flatten_parameters: + expected_reshard_hook_count *= 2 # weight and bias per linear + assert ( + reshard_hook_count == expected_reshard_hook_count + ), f"Expected {expected_reshard_hook_count} but got {reshard_hook_count}" + assert losses[0].eq(losses[1]).all().item(), f"Expected {losses[1]} but got {losses[0]}" + reshard_hook_count = 0
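The `register_multi_grad_hook` helper added in `fully_sharded_data_parallel.py` fires its callback at most once per backward pass, and only after gradients for all hooked tensors have been produced; FSDP uses it to reshard frozen (requires_grad=False) parameters once the module inputs' gradients are available. Below is a minimal standalone sketch of that behaviour. The import path is an assumption (the helper is module-level in this patch and not necessarily public API).

```python
# Minimal sketch, assuming the helper stays a module-level function.
import torch
from fairscale.nn.data_parallel.fully_sharded_data_parallel import register_multi_grad_hook

a = torch.randn(4, requires_grad=True)
b = torch.randn(4, requires_grad=True)

def on_all_grads_ready(grads):
    # Invoked once inside backward(), after gradients for *all* hooked tensors
    # have been computed for this graph task (the condition FSDP waits for
    # before resharding frozen parameters).
    print("got", len(grads), "grads")

handle = register_multi_grad_hook([a, b], on_all_grads_ready)
(a * b).sum().backward()  # callback fires here, inside the backward call
handle.remove()           # Handle.remove() removes both per-tensor hooks
```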
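Separately, the `optimize_backward_concat` path in `FlattenParamsWrapper` accumulates each parameter view's gradient in fp32 via tensor hooks and only builds the concatenated `unsharded_main_grad` once `_require_backward_grad_sync` is set (i.e. on the last microbatch). The sketch below illustrates that accumulate-then-concat idea with plain tensors; it is a simplified analogue under assumed shapes, not fairscale code.

```python
# Illustrative sketch of the accumulate-then-concat idea (not fairscale code).
import torch

# A flat low-precision "parameter" and two views of it, standing in for
# FlattenParamsWrapper's flat_param and the per-module views it hands out.
flat = torch.randn(17, dtype=torch.bfloat16, requires_grad=True)
numels, shapes = [12, 5], [(3, 4), (5,)]

fp32_grads = [None, None]  # analogous to FlattenParamsWrapper.fp32_grads

def make_hook(i):
    def hook(grad):
        # Accumulate each view's gradient in fp32 across microbatches,
        # without building the flat gradient yet.
        if fp32_grads[i] is None:
            fp32_grads[i] = grad.to(torch.float32)
        else:
            fp32_grads[i].add_(grad)
        return grad
    return hook

for _ in range(4):  # microbatches
    views = [t.view(s) for t, s in zip(flat.split(numels), shapes)]
    for i, v in enumerate(views):
        v.register_hook(make_hook(i))
    loss = sum((v.float() ** 2).sum() for v in views)
    loss.backward()

# What the FSDP post-backward hook does once, on the synced microbatch,
# instead of concatenating after every backward pass:
unsharded_main_grad = torch.cat([g.flatten() for g in fp32_grads])
```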