diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index d112221dff2..7b6aa26f245 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -1848,8 +1848,25 @@ def _prepare_deepspeed(self, *args): engine, optimizer, _, lr_scheduler = ds_initialize(**kwargs) if compare_versions("deepspeed", ">=", "0.14.4") and self.state.dynamo_plugin.backend != DynamoBackend.NO: + from deepspeed.utils import z3_leaf_module, set_z3_leaf_modules + + if transformers_no_split_modules := getattr(model, "_no_split_modules", None): + leaf_cls = [] + for module in model.modules(): + if module.__class__ in leaf_cls: + continue + elif z3_leaf_module(module): + # the user has already specified a z3 leaf module; do not override it + leaf_cls.clear() + break + elif module.__class__.__name__ in transformers_no_split_modules: + leaf_cls.append(module.__class__) + if leaf_cls: + set_z3_leaf_modules(model, leaf_cls) + compile_kwargs = self.state.dynamo_plugin.to_kwargs() engine.compile(backend=compile_kwargs.pop("backend"), compile_kwargs=compile_kwargs) + if optimizer is not None: optimizer = DeepSpeedOptimizerWrapper(optimizer) if scheduler is not None: