diff --git a/docs/source/usage_guides/fsdp.md b/docs/source/usage_guides/fsdp.md
index 01ae00508b6..c1ed0415c85 100644
--- a/docs/source/usage_guides/fsdp.md
+++ b/docs/source/usage_guides/fsdp.md
@@ -46,10 +46,10 @@ downcast_bf16: 'no'
 fsdp_config:
   fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
   fsdp_backward_prefetch_policy: BACKWARD_PRE
-  fsdp_forward_prefetch: true
+  fsdp_forward_prefetch: false
   fsdp_cpu_ram_efficient_loading: true
   fsdp_offload_params: false
-  fsdp_sharding_strategy: 1
+  fsdp_sharding_strategy: FULL_SHARD
   fsdp_state_dict_type: SHARDED_STATE_DICT
   fsdp_sync_module_states: true
   fsdp_transformer_layer_cls_to_wrap: BertLayer
diff --git a/src/accelerate/commands/config/cluster.py b/src/accelerate/commands/config/cluster.py
index 85d13d19cc5..f6c1fc9ac33 100644
--- a/src/accelerate/commands/config/cluster.py
+++ b/src/accelerate/commands/config/cluster.py
@@ -327,8 +327,7 @@ def get_cluster_input():
         fsdp_config["fsdp_sharding_strategy"] = _ask_options(
             sharding_strategy_query,
             FSDP_SHARDING_STRATEGY,
-            lambda x: int(x) + 1,
-            default=1,
+            lambda x: FSDP_SHARDING_STRATEGY[int(x)],
         )
         fsdp_config["fsdp_offload_params"] = _ask_field(
             "Do you want to offload parameters and gradients to CPU? [yes/NO]: ",
@@ -362,7 +361,7 @@ def get_cluster_input():
             default=100000000,
         )
         fsdp_backward_prefetch_query = "What should be your FSDP's backward prefetch policy?"
-        fsdp_config["fsdp_backward_prefetch_policy"] = _ask_options(
+        fsdp_config["fsdp_backward_prefetch"] = _ask_options(
             fsdp_backward_prefetch_query,
             FSDP_BACKWARD_PREFETCH,
             lambda x: FSDP_BACKWARD_PREFETCH[int(x)],
diff --git a/src/accelerate/commands/launch.py b/src/accelerate/commands/launch.py
index 66bbea21707..86038f425e5 100644
--- a/src/accelerate/commands/launch.py
+++ b/src/accelerate/commands/launch.py
@@ -482,8 +482,8 @@ def launch_command_parser(subparsers=None):
     )
     fsdp_args.add_argument(
         "--fsdp_sharding_strategy",
-        type=int,
-        default=1,
+        type=str,
+        default="FULL_SHARD",
         help="FSDP's Sharding Strategy. (useful only when `use_fsdp` flag is passed).",
     )
     fsdp_args.add_argument(
@@ -503,6 +503,12 @@ def launch_command_parser(subparsers=None):
         "--fsdp_backward_prefetch_policy",
         default=None,
         type=str,
+        help="This argument is deprecated and will be removed in version 0.27.0 of 🤗 Accelerate. Use `fsdp_backward_prefetch` instead.",
+    )
+    fsdp_args.add_argument(
+        "--fsdp_backward_prefetch",
+        default=None,
+        type=str,
         help="FSDP's backward prefetch policy. (useful only when `use_fsdp` flag is passed).",
     )
     fsdp_args.add_argument(
diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py
index 917861eed64..9da98de72a2 100644
--- a/src/accelerate/utils/dataclasses.py
+++ b/src/accelerate/utils/dataclasses.py
@@ -30,7 +30,7 @@
 
 import torch
 
-from .constants import FSDP_AUTO_WRAP_POLICY, FSDP_BACKWARD_PREFETCH, FSDP_STATE_DICT_TYPE
+from .constants import FSDP_AUTO_WRAP_POLICY, FSDP_BACKWARD_PREFETCH, FSDP_SHARDING_STRATEGY, FSDP_STATE_DICT_TYPE
 from .environment import str_to_bool
 from .imports import is_cuda_available, is_npu_available, is_xpu_available
 from .versions import compare_versions
@@ -439,6 +439,7 @@ class CustomDtype(enum.Enum):
     r"""
     An enum that contains multiple custom dtypes that can be used for `infer_auto_device_map`.
     """
+
     FP8 = "fp8"
     INT4 = "int4"
 
@@ -918,7 +919,7 @@ class FullyShardedDataParallelPlugin:
         },
     )
     limit_all_gathers: bool = field(
-        default=False,
+        default=True,
         metadata={
             "help": "If False, then FSDP allows the CPU thread to schedule all-gathers "
             "without any extra synchronization. If True, then FSDP explicitly synchronizes the CPU thread to prevent "
@@ -929,9 +930,10 @@
     use_orig_params: bool = field(
         default=True,
         metadata={
-            "help": "If True, allows non-uniform `requires_grad` during init, which means support for interspersed frozen and trainable paramteres. "
+            "help": "If `True`, allows non-uniform `requires_grad` during init, which means support for interspersed frozen and trainable parameters. "
             "Useful in cases such as parameter-efficient fine-tuning. "
-            "Please refer this [blog](https://dev-discuss.pytorch.org/t/rethinking-pytorch-fully-sharded-data-parallel-fsdp-from-first-principles/1019)"
+            "Please refer to this [blog](https://dev-discuss.pytorch.org/t/rethinking-pytorch-fully-sharded-data-parallel-fsdp-from-first-principles/1019). "
+            "This also enables having different optimizer param groups. This should be `True` when creating the optimizer object before preparing/wrapping the model with FSDP."
         },
     )
     param_init_fn: Optional[Callable[[torch.nn.Module], None]] = field(
@@ -969,7 +971,13 @@ def __post_init__(self):
         prefix = "FSDP_"
 
         if self.sharding_strategy is None:
-            self.sharding_strategy = ShardingStrategy(int(os.environ.get(prefix + "SHARDING_STRATEGY", 1)))
+            sharding_strategy = os.environ.get(prefix + "SHARDING_STRATEGY", "FULL_SHARD")
+            sharding_strategy = (
+                FSDP_SHARDING_STRATEGY.index(sharding_strategy) + 1
+                if not sharding_strategy.isdigit()
+                else int(sharding_strategy)
+            )
+            self.sharding_strategy = ShardingStrategy(sharding_strategy)
 
         if self.cpu_offload is None:
             if str_to_bool(os.environ.get(prefix + "OFFLOAD_PARAMS", "False")) == 1:
diff --git a/src/accelerate/utils/launch.py b/src/accelerate/utils/launch.py
index a299343d90b..e748e371bc8 100644
--- a/src/accelerate/utils/launch.py
+++ b/src/accelerate/utils/launch.py
@@ -15,6 +15,7 @@
 import argparse
 import os
 import sys
+import warnings
 from ast import literal_eval
 from typing import Any, Dict, List, Tuple
 
@@ -188,7 +189,14 @@ def prepare_multi_gpu_env(args: argparse.Namespace) -> Dict[str, str]:
         if args.fsdp_transformer_layer_cls_to_wrap is not None:
             current_env["FSDP_TRANSFORMER_CLS_TO_WRAP"] = str(args.fsdp_transformer_layer_cls_to_wrap)
         if args.fsdp_backward_prefetch_policy is not None:
-            current_env["FSDP_BACKWARD_PREFETCH"] = str(args.fsdp_backward_prefetch_policy)
+            warnings.warn(
+                "`fsdp_backward_prefetch_policy` is deprecated and will be removed in version 0.27.0 of 🤗 Accelerate. Use"
+                " `fsdp_backward_prefetch` instead",
+                FutureWarning,
+            )
+            args.fsdp_backward_prefetch = args.fsdp_backward_prefetch_policy
+        if args.fsdp_backward_prefetch is not None:
+            current_env["FSDP_BACKWARD_PREFETCH"] = str(args.fsdp_backward_prefetch)
         if args.fsdp_state_dict_type is not None:
             current_env["FSDP_STATE_DICT_TYPE"] = str(args.fsdp_state_dict_type)
         current_env["FSDP_FORWARD_PREFETCH"] = str(args.fsdp_forward_prefetch).lower()
diff --git a/tests/fsdp/test_fsdp.py b/tests/fsdp/test_fsdp.py
index c494f5e2d2e..0f490c80f2c 100644
--- a/tests/fsdp/test_fsdp.py
+++ b/tests/fsdp/test_fsdp.py
@@ -69,10 +69,18 @@ def setUp(self):
     def test_sharding_strategy(self):
         from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy
 
+        # check that giving enums works fine
         for i, strategy in enumerate(FSDP_SHARDING_STRATEGY):
             env = self.dist_env.copy()
             env["FSDP_SHARDING_STRATEGY"] = f"{i + 1}"
-            env["FSDP_SHARDING_STRATEGY_NAME"] = strategy
+            with mockenv_context(**env):
+                fsdp_plugin = FullyShardedDataParallelPlugin()
+                self.assertEqual(fsdp_plugin.sharding_strategy, ShardingStrategy(i + 1))
+
+        # check that giving names works fine
+        for i, strategy in enumerate(FSDP_SHARDING_STRATEGY):
+            env = self.dist_env.copy()
+            env["FSDP_SHARDING_STRATEGY"] = strategy
             with mockenv_context(**env):
                 fsdp_plugin = FullyShardedDataParallelPlugin()
                 self.assertEqual(fsdp_plugin.sharding_strategy, ShardingStrategy(i + 1))
@@ -201,7 +209,7 @@ def test_performance(self):
             cmd_config = cmd.copy()
             for i, strategy in enumerate(FSDP_SHARDING_STRATEGY):
                 if strategy.lower() in config:
-                    cmd_config.append(f"--fsdp_sharding_strategy={i+1}")
+                    cmd_config.append(f"--fsdp_sharding_strategy={strategy}")
                     break
 
             if "fp32" in config:
@@ -247,7 +255,7 @@ def test_checkpointing(self):
 
         for i, strategy in enumerate(FSDP_SHARDING_STRATEGY):
             cmd_config = cmd.copy()
-            cmd_config.append(f"--fsdp_sharding_strategy={i+1}")
+            cmd_config.append(f"--fsdp_sharding_strategy={strategy}")
             if strategy != "FULL_SHARD":
                 continue
             state_dict_config_index = len(cmd_config)
@@ -301,7 +309,7 @@ def test_peak_memory_usage(self):
             cmd_config.extend(["--use_fsdp"])
             for i, strategy in enumerate(FSDP_SHARDING_STRATEGY):
                 if strategy.lower() in spec:
-                    cmd_config.append(f"--fsdp_sharding_strategy={i+1}")
+                    cmd_config.append(f"--fsdp_sharding_strategy={strategy}")
                     break
 
             if "cpu_offload" in spec: