[doc] FSDP improvements (#2274)
* Update fsdp.md

* fix typo

* fix readability

* resolve the "static models" ambiguity

* rewrite section

* typo
stas00 authored Dec 27, 2023
1 parent d1abd59 commit 3db088f
Showing 1 changed file with 16 additions and 28 deletions.
docs/source/usage_guides/fsdp.md (44 changes: 16 additions & 28 deletions)
@@ -36,7 +36,7 @@ default options when doing
accelerate launch my_script.py --args_to_my_script
```

For instance, here is how you would run `examples/nlp_example.py` (from the root of the repo) with FSDP enabled:

```yaml
compute_environment: LOCAL_MACHINE
@@ -46,8 +46,8 @@ downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_sharding_strategy: 1
  fsdp_state_dict_type: SHARDED_STATE_DICT
@@ -73,42 +73,30 @@ accelerate launch examples/nlp_example.py

Currently, `Accelerate` supports the following configuration options through the CLI:

`fsdp_sharding_strategy`: [1] FULL_SHARD (shards optimizer states, gradients and parameters), [2] SHARD_GRAD_OP (shards optimizer states and gradients), [3] NO_SHARD (DDP), [4] HYBRID_SHARD (shards optimizer states, gradients and parameters within each node, while each node has a full copy), [5] HYBRID_SHARD_ZERO2 (shards optimizer states and gradients within each node, while each node has a full copy)

`fsdp_offload_params`: decides whether to offload parameters and gradients to CPU.

`fsdp_auto_wrap_policy`: [1] TRANSFORMER_BASED_WRAP, [2] SIZE_BASED_WRAP, [3] NO_WRAP

`fsdp_transformer_layer_cls_to_wrap`: Only applicable for 🤗 Transformers. When using `fsdp_auto_wrap_policy=TRANSFORMER_BASED_WRAP`, a user may provide a comma-separated string of transformer layer class names (case-sensitive) to wrap, e.g., `BertLayer`, `GPTJBlock`, `T5Block`, `BertLayer,BertEmbeddings,BertSelfOutput`. This is important because submodules that share weights (e.g., embedding layers) should not end up in different FSDP wrapped units. With this policy, wrapping happens for each block containing Multi-Head Attention followed by a couple of MLP layers; the remaining layers, including the shared embeddings, are conveniently wrapped in the same outermost FSDP unit. Therefore, use this for transformer-based models. For 🤗 Transformers models, you can answer `yes` to the question "Do you want to use the model's `_no_split_modules` to wrap?" and it will try to use `model._no_split_modules` when possible.

`fsdp_min_num_params`: minimum number of parameters a module must have to be wrapped when using `fsdp_auto_wrap_policy=SIZE_BASED_WRAP`.

`fsdp_backward_prefetch_policy`: [1] BACKWARD_PRE, [2] BACKWARD_POST, [3] NO_PREFETCH

`fsdp_forward_prefetch`: if True, then FSDP explicitly prefetches the next upcoming all-gather while executing in the forward pass. Should only be used for static-graph models, since the prefetching follows the first iteration's execution order; i.e., if the order of the sub-modules changes dynamically during the model's execution, do not enable this feature.

`fsdp_state_dict_type`: [1] FULL_STATE_DICT, [2] LOCAL_STATE_DICT, [3] SHARDED_STATE_DICT

`fsdp_use_orig_params`: If True, allows non-uniform `requires_grad` during init, which means support for interspersed frozen and trainable parameters. This setting is useful in cases such as parameter-efficient fine-tuning, as discussed in [this post](https://dev-discuss.pytorch.org/t/rethinking-pytorch-fully-sharded-data-parallel-fsdp-from-first-principles/1019). This option also allows one to have multiple optimizer param groups. This should be `True` when creating an optimizer before preparing/wrapping the model with FSDP (see the sketch after this list of options).

`fsdp_cpu_ram_efficient_loading`: Only applicable for 🤗 Transformers models. If True, only the first process loads the pretrained model checkpoint while all other processes have empty weights. This should be set to False if you experience errors when loading the pretrained 🤗 Transformers model via the `from_pretrained` method. When this setting is True, `fsdp_sync_module_states` must also be True, otherwise all the processes except the main process would have random weights, leading to unexpected behaviour during training.

`fsdp_sync_module_states`: If True, each individually wrapped FSDP unit will broadcast module parameters from rank 0.
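
As a minimal sketch of the optimizer ordering mentioned for `fsdp_use_orig_params` above (the tiny model here is only a stand-in for a real one), creating the optimizer before `accelerator.prepare` could look like this:

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()  # assumes an FSDP config with `fsdp_use_orig_params: true`

model = torch.nn.Linear(8, 2)  # stand-in for a real model
# Creating the optimizer before `prepare()` relies on `fsdp_use_orig_params: true`
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# `prepare()` wraps the model with FSDP and adjusts the optimizer accordingly
model, optimizer = accelerator.prepare(model, optimizer)
```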

For additional and more nuanced control, you can specify other FSDP parameters via `FullyShardedDataParallelPlugin`.
When creating the `FullyShardedDataParallelPlugin` object, pass it the parameters that weren't part of the accelerate config, or ones whose values you want to override. The FSDP parameters will first be picked up from the accelerate config file or launch command arguments; any parameters that you pass directly through the `FullyShardedDataParallelPlugin` object will then set/override them.
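
For example, a minimal sketch of overriding the state-dict behaviour directly through the plugin (using the standard PyTorch FSDP config classes) could look like this:

```python
from accelerate import Accelerator, FullyShardedDataParallelPlugin
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    FullOptimStateDictConfig,
    FullStateDictConfig,
)

# Values set here override the corresponding entries from the accelerate
# config file or launch command arguments.
fsdp_plugin = FullyShardedDataParallelPlugin(
    state_dict_config=FullStateDictConfig(offload_to_cpu=False, rank0_only=False),
    optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=False, rank0_only=False),
)

accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
```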

@@ -137,7 +125,7 @@ accelerator.save_state("ckpt")

Inspect the checkpoint folder to see the model and optimizer as shards per process:
```
ls ckpt
# optimizer_0 pytorch_model_0 random_states_0.pkl random_states_1.pkl scheduler.bin
cd ckpt
@@ -155,7 +143,7 @@ To load them back for resuming the training, use the `load_state` utility of acc
accelerator.load_state("ckpt")
```
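
Putting the two calls together, an illustrative sketch (the tiny model is only a stand-in) of saving and later restoring sharded state around a prepared model and optimizer could look like this:

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()  # launched with an FSDP config

model = torch.nn.Linear(8, 2)  # stand-in for a real model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model, optimizer = accelerator.prepare(model, optimizer)

# ... training steps ...

# Saves model/optimizer shards and RNG states under ./ckpt
accelerator.save_state("ckpt")

# To resume later: rebuild and prepare the same objects, then restore them
accelerator.load_state("ckpt")
```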

When using transformers `save_pretrained`, pass `state_dict=accelerator.get_state_dict(model)` to save the model state dict.
Below is an example:

```diff
  unwrapped_model = accelerator.unwrap_model(model)
  unwrapped_model.save_pretrained(
      save_directory,
      is_main_process=accelerator.is_main_process,
      save_function=accelerator.save,
+     state_dict=accelerator.get_state_dict(model),
  )
```
