@@ -132,6 +132,7 @@ def get_fsdp_config(self):
         from functools import partial
 
         # Third Party
+        from accelerate.utils import FullyShardedDataParallelPlugin
         from peft.utils.other import fsdp_auto_wrap_policy
         from torch.distributed.fsdp import BackwardPrefetch, ShardingStrategy
         from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
@@ -159,27 +160,17 @@ def get_fsdp_config(self):
         prefetch_policy = (
             BackwardPrefetch.BACKWARD_POST if is_lora else BackwardPrefetch.BACKWARD_PRE
         )
+        fsdp_plugin = FullyShardedDataParallelPlugin(
+            auto_wrap_policy=wrap_policy,
+            limit_all_gathers=True,
+            backward_prefetch=prefetch_policy,
+            sharding_strategy=ShardingStrategy[self.fsdp_sharding_strategy],
+            cpu_offload=CPUOffload(self.fsdp_cpu_offload_params),
+        )
 
         if self.device_str == "hpu":
-            from optimum.habana.accelerate.utils import GaudiFullyShardedDataParallelPlugin
-            fsdp_plugin = GaudiFullyShardedDataParallelPlugin(
-                auto_wrap_policy=wrap_policy,
-                limit_all_gathers=True,
-                backward_prefetch=prefetch_policy,
-                sharding_strategy=ShardingStrategy[self.fsdp_sharding_strategy],
-                cpu_offload=CPUOffload(self.fsdp_cpu_offload_params),
-            )
             fsdp_plugin.use_orig_params = True
             fsdp_plugin.sync_module_states = True
-        else:
-            from accelerate.utils import FullyShardedDataParallelPlugin
-            fsdp_plugin = FullyShardedDataParallelPlugin(
-                auto_wrap_policy=wrap_policy,
-                limit_all_gathers=True,
-                backward_prefetch=prefetch_policy,
-                sharding_strategy=ShardingStrategy[self.fsdp_sharding_strategy],
-                cpu_offload=CPUOffload(self.fsdp_cpu_offload_params),
-            )
 
         # `use_orig_params` must be disabled when using LoRA and FSDP together
         # Source: https://huggingface.co/docs/peft/en/accelerate/fsdp#the-important-parts
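For readers following along, here is a minimal sketch (not part of this commit) of how a `FullyShardedDataParallelPlugin` configured this way is typically handed to Accelerate. The sharding-strategy string and offload flag are placeholder values, `auto_wrap_policy` is omitted because it needs an instantiated model, and the snippet assumes it runs under a distributed launcher such as `accelerate launch` or `torchrun`:

```python
# Sketch only: placeholder values, not the repo's actual wiring.
from accelerate import Accelerator
from accelerate.utils import FullyShardedDataParallelPlugin
from torch.distributed.fsdp import BackwardPrefetch, ShardingStrategy
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload

fsdp_plugin = FullyShardedDataParallelPlugin(
    limit_all_gathers=True,
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
    # The commit resolves the strategy from a config string the same way:
    sharding_strategy=ShardingStrategy["FULL_SHARD"],
    cpu_offload=CPUOffload(offload_params=False),
)

# Device-specific tweaks now land on the one shared plugin instance,
# as the commit's HPU branch does:
fsdp_plugin.use_orig_params = True
fsdp_plugin.sync_module_states = True

accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
```

The point of the consolidation is visible here: a single `FullyShardedDataParallelPlugin` is built for every device, and backend-specific attributes are set on that instance afterwards instead of constructing a separate Gaudi-specific plugin.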