Improve FSDP config usability (#2288)
* Improve FSDP config usability

* quality ✨

* Update tests

* fix cmd arg

* fix

* update docs

* address comments
pacman100 authored Dec 27, 2023
1 parent ad957ce commit 848ed80
Showing 6 changed files with 46 additions and 17 deletions.
docs/source/usage_guides/fsdp.md (4 changes: 2 additions & 2 deletions)
@@ -46,10 +46,10 @@ downcast_bf16: 'no'
 fsdp_config:
   fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
   fsdp_backward_prefetch_policy: BACKWARD_PRE
-  fsdp_forward_prefetch: true
+  fsdp_forward_prefetch: false
   fsdp_cpu_ram_efficient_loading: true
   fsdp_offload_params: false
-  fsdp_sharding_strategy: 1
+  fsdp_sharding_strategy: FULL_SHARD
   fsdp_state_dict_type: SHARDED_STATE_DICT
   fsdp_sync_module_states: true
   fsdp_transformer_layer_cls_to_wrap: BertLayer
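Note (illustrative, not part of the diff): the new YAML value is just the name of the corresponding member of torch's ShardingStrategy enum, so the config now says explicitly what the old numeric value 1 meant. A minimal standalone sketch:

    from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy

    # "FULL_SHARD" is the name of ShardingStrategy(1); looking the member up by name
    # or by number yields the same enum value.
    assert ShardingStrategy["FULL_SHARD"] is ShardingStrategy(1)
    print(ShardingStrategy["FULL_SHARD"])  # ShardingStrategy.FULL_SHARD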
src/accelerate/commands/config/cluster.py (5 changes: 2 additions & 3 deletions)
@@ -327,8 +327,7 @@ def get_cluster_input():
fsdp_config["fsdp_sharding_strategy"] = _ask_options(
sharding_strategy_query,
FSDP_SHARDING_STRATEGY,
lambda x: int(x) + 1,
default=1,
lambda x: FSDP_SHARDING_STRATEGY[int(x)],
)
fsdp_config["fsdp_offload_params"] = _ask_field(
"Do you want to offload parameters and gradients to CPU? [yes/NO]: ",
@@ -362,7 +361,7 @@ def get_cluster_input():
             default=100000000,
         )
         fsdp_backward_prefetch_query = "What should be your FSDP's backward prefetch policy?"
-        fsdp_config["fsdp_backward_prefetch_policy"] = _ask_options(
+        fsdp_config["fsdp_backward_prefetch"] = _ask_options(
             fsdp_backward_prefetch_query,
             FSDP_BACKWARD_PREFETCH,
             lambda x: FSDP_BACKWARD_PREFETCH[int(x)],
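Note (illustrative, not part of the diff): the effect of the new conversion lambda, sketched standalone. The strategy list below is an illustrative subset (the real constant is accelerate.utils.constants.FSDP_SHARDING_STRATEGY), and choice stands in for the index returned by the prompt:

    # Illustrative subset of the strategy names; the order mirrors torch's 1-based enum.
    FSDP_SHARDING_STRATEGY = ["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD"]

    choice = "0"  # the menu index the user selects

    old_value = int(choice) + 1                      # previously stored in the config: 1
    new_value = FSDP_SHARDING_STRATEGY[int(choice)]  # now stored: "FULL_SHARD"
    print(old_value, new_value)  # 1 FULL_SHARD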
src/accelerate/commands/launch.py (10 changes: 8 additions & 2 deletions)
@@ -482,8 +482,8 @@ def launch_command_parser(subparsers=None):
     )
     fsdp_args.add_argument(
         "--fsdp_sharding_strategy",
-        type=int,
-        default=1,
+        type=str,
+        default="FULL_SHARD",
         help="FSDP's Sharding Strategy. (useful only when `use_fsdp` flag is passed).",
     )
     fsdp_args.add_argument(
@@ -503,6 +503,12 @@ def launch_command_parser(subparsers=None):
"--fsdp_backward_prefetch_policy",
default=None,
type=str,
help="This argument is deprecated and will be removed in version 0.27.0 of 🤗 Accelerate. Use `fsdp_backward_prefetch` instead.",
)
fsdp_args.add_argument(
"--fsdp_backward_prefetch",
default=None,
type=str,
help="FSDP's backward prefetch policy. (useful only when `use_fsdp` flag is passed).",
)
fsdp_args.add_argument(
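Note (illustrative, not part of the diff): a self-contained argparse sketch of the resulting flag surface. The flag names and defaults are taken from the diff above; the parser setup and sample values are assumptions:

    import argparse

    parser = argparse.ArgumentParser()
    # The sharding strategy is now passed by name instead of by number.
    parser.add_argument("--fsdp_sharding_strategy", type=str, default="FULL_SHARD")
    # The old flag is kept only as a deprecated alias for the new one.
    parser.add_argument("--fsdp_backward_prefetch_policy", type=str, default=None)
    parser.add_argument("--fsdp_backward_prefetch", type=str, default=None)

    args = parser.parse_args(["--fsdp_backward_prefetch", "BACKWARD_PRE"])
    print(args.fsdp_sharding_strategy, args.fsdp_backward_prefetch)  # FULL_SHARD BACKWARD_PRE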
src/accelerate/utils/dataclasses.py (18 changes: 13 additions & 5 deletions)
@@ -30,7 +30,7 @@

 import torch
 
-from .constants import FSDP_AUTO_WRAP_POLICY, FSDP_BACKWARD_PREFETCH, FSDP_STATE_DICT_TYPE
+from .constants import FSDP_AUTO_WRAP_POLICY, FSDP_BACKWARD_PREFETCH, FSDP_SHARDING_STRATEGY, FSDP_STATE_DICT_TYPE
 from .environment import str_to_bool
 from .imports import is_cuda_available, is_npu_available, is_xpu_available
 from .versions import compare_versions
@@ -439,6 +439,7 @@ class CustomDtype(enum.Enum):
r"""
An enum that contains multiple custom dtypes that can be used for `infer_auto_device_map`.
"""

FP8 = "fp8"
INT4 = "int4"

@@ -918,7 +919,7 @@ class FullyShardedDataParallelPlugin:
         },
     )
     limit_all_gathers: bool = field(
-        default=False,
+        default=True,
         metadata={
             "help": "If False, then FSDP allows the CPU thread to schedule all-gathers "
             "without any extra synchronization. If True, then FSDP explicitly synchronizes the CPU thread to prevent "
@@ -929,9 +930,10 @@ class FullyShardedDataParallelPlugin:
     use_orig_params: bool = field(
         default=True,
         metadata={
-            "help": "If True, allows non-uniform `requires_grad` during init, which means support for interspersed frozen and trainable paramteres. "
+            "help": "If `True`, allows non-uniform `requires_grad` during init, which means support for interspersed frozen and trainable paramteres. "
             "Useful in cases such as parameter-efficient fine-tuning. "
-            "Please refer this [blog](https://dev-discuss.pytorch.org/t/rethinking-pytorch-fully-sharded-data-parallel-fsdp-from-first-principles/1019)"
+            "Please refer this [blog](https://dev-discuss.pytorch.org/t/rethinking-pytorch-fully-sharded-data-parallel-fsdp-from-first-principles/1019). "
+            "This also enables to have different optimizer param groups. This should be `True` when creating optimizer object before preparing/wrapping the model with FSDP."
         },
     )
     param_init_fn: Optional[Callable[[torch.nn.Module], None]] = field(
@@ -969,7 +971,13 @@ def __post_init__(self):

prefix = "FSDP_"
if self.sharding_strategy is None:
self.sharding_strategy = ShardingStrategy(int(os.environ.get(prefix + "SHARDING_STRATEGY", 1)))
sharding_strategy = os.environ.get(prefix + "SHARDING_STRATEGY", "FULL_SHARD")
sharding_strategy = (
FSDP_SHARDING_STRATEGY.index(sharding_strategy) + 1
if not sharding_strategy.isdigit()
else int(sharding_strategy)
)
self.sharding_strategy = ShardingStrategy(sharding_strategy)

if self.cpu_offload is None:
if str_to_bool(os.environ.get(prefix + "OFFLOAD_PARAMS", "False")) == 1:
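Note (illustrative, not part of the diff): the new __post_init__ accepts either the legacy 1-based integer or a strategy name from the FSDP_SHARDING_STRATEGY environment variable. A standalone sketch of that parsing, using an illustrative subset of the real constant:

    from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy

    FSDP_SHARDING_STRATEGY = ["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD"]  # illustrative subset

    for raw in ("1", "FULL_SHARD"):  # legacy numeric form and the new named form
        value = int(raw) if raw.isdigit() else FSDP_SHARDING_STRATEGY.index(raw) + 1
        print(raw, "->", ShardingStrategy(value))  # both map to ShardingStrategy.FULL_SHARD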
src/accelerate/utils/launch.py (10 changes: 9 additions & 1 deletion)
@@ -15,6 +15,7 @@
 import argparse
 import os
 import sys
+import warnings
 from ast import literal_eval
 from typing import Any, Dict, List, Tuple

@@ -188,7 +189,14 @@ def prepare_multi_gpu_env(args: argparse.Namespace) -> Dict[str, str]:
         if args.fsdp_transformer_layer_cls_to_wrap is not None:
             current_env["FSDP_TRANSFORMER_CLS_TO_WRAP"] = str(args.fsdp_transformer_layer_cls_to_wrap)
         if args.fsdp_backward_prefetch_policy is not None:
-            current_env["FSDP_BACKWARD_PREFETCH"] = str(args.fsdp_backward_prefetch_policy)
+            warnings.warn(
+                "`fsdp_backward_prefetch_policy` is deprecated and will be removed in version 0.27.0 of 🤗 Accelerate. Use"
+                " `fsdp_backward_prefetch` instead",
+                FutureWarning,
+            )
+            args.fsdp_backward_prefetch = args.fsdp_backward_prefetch_policy
+        if args.fsdp_backward_prefetch is not None:
+            current_env["FSDP_BACKWARD_PREFETCH"] = str(args.fsdp_backward_prefetch)
         if args.fsdp_state_dict_type is not None:
             current_env["FSDP_STATE_DICT_TYPE"] = str(args.fsdp_state_dict_type)
         current_env["FSDP_FORWARD_PREFETCH"] = str(args.fsdp_forward_prefetch).lower()
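Note (illustrative, not part of the diff): isolated from the launcher, the shim above boils down to the following; the helper name is hypothetical, and only the warn-and-forward behaviour mirrors the diff:

    import warnings

    def resolve_backward_prefetch(deprecated_value, new_value):
        # If the old flag was used, forward its value to the new one and warn.
        if deprecated_value is not None:
            warnings.warn(
                "`fsdp_backward_prefetch_policy` is deprecated, use `fsdp_backward_prefetch` instead",
                FutureWarning,
            )
            new_value = deprecated_value
        return new_value

    print(resolve_backward_prefetch("BACKWARD_PRE", None))  # BACKWARD_PRE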
tests/fsdp/test_fsdp.py (16 changes: 12 additions & 4 deletions)
@@ -69,10 +69,18 @@ def setUp(self):
     def test_sharding_strategy(self):
         from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy
 
+        # check that giving enums works fine
         for i, strategy in enumerate(FSDP_SHARDING_STRATEGY):
             env = self.dist_env.copy()
             env["FSDP_SHARDING_STRATEGY"] = f"{i + 1}"
+            env["FSDP_SHARDING_STRATEGY_NAME"] = strategy
             with mockenv_context(**env):
                 fsdp_plugin = FullyShardedDataParallelPlugin()
                 self.assertEqual(fsdp_plugin.sharding_strategy, ShardingStrategy(i + 1))
 
+        # check that giving names works fine
+        for i, strategy in enumerate(FSDP_SHARDING_STRATEGY):
+            env = self.dist_env.copy()
+            env["FSDP_SHARDING_STRATEGY"] = strategy
+            with mockenv_context(**env):
+                fsdp_plugin = FullyShardedDataParallelPlugin()
+                self.assertEqual(fsdp_plugin.sharding_strategy, ShardingStrategy(i + 1))
@@ -201,7 +209,7 @@ def test_performance(self):
             cmd_config = cmd.copy()
             for i, strategy in enumerate(FSDP_SHARDING_STRATEGY):
                 if strategy.lower() in config:
-                    cmd_config.append(f"--fsdp_sharding_strategy={i+1}")
+                    cmd_config.append(f"--fsdp_sharding_strategy={strategy}")
                     break
 
             if "fp32" in config:
@@ -247,7 +255,7 @@ def test_checkpointing(self):

         for i, strategy in enumerate(FSDP_SHARDING_STRATEGY):
             cmd_config = cmd.copy()
-            cmd_config.append(f"--fsdp_sharding_strategy={i+1}")
+            cmd_config.append(f"--fsdp_sharding_strategy={strategy}")
             if strategy != "FULL_SHARD":
                 continue
             state_dict_config_index = len(cmd_config)
@@ -301,7 +309,7 @@ def test_peak_memory_usage(self):
cmd_config.extend(["--use_fsdp"])
for i, strategy in enumerate(FSDP_SHARDING_STRATEGY):
if strategy.lower() in spec:
cmd_config.append(f"--fsdp_sharding_strategy={i+1}")
cmd_config.append(f"--fsdp_sharding_strategy={strategy}")
break

if "cpu_offload" in spec:
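Note (illustrative, not part of the diff): the launcher flag the tests build before and after this change, reduced to minimal command lists:

    old_cmd = ["accelerate", "launch", "--use_fsdp", "--fsdp_sharding_strategy=1"]
    new_cmd = ["accelerate", "launch", "--use_fsdp", "--fsdp_sharding_strategy=FULL_SHARD"]
    print(old_cmd[-1], "->", new_cmd[-1])  # --fsdp_sharding_strategy=1 -> --fsdp_sharding_strategy=FULL_SHARD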
