diff --git a/src/accelerate/state.py b/src/accelerate/state.py
index c3a594de5d7..5ebad46c1be 100644
--- a/src/accelerate/state.py
+++ b/src/accelerate/state.py
@@ -695,12 +695,14 @@ def default_device(self) -> torch.device:
             return torch.device("mlu")
         elif is_musa_available():
             return torch.device("musa")
+        # NPU should be checked before CUDA when using `transfer_to_npu`
+        # See issue #3020: https://github.com/huggingface/accelerate/issues/3020
+        elif is_npu_available():
+            return torch.device("npu")
         elif torch.cuda.is_available():
             return torch.device("cuda")
         elif is_xpu_available():
             return torch.device("xpu:0")
-        elif is_npu_available():
-            return torch.device("npu")
         else:
             return torch.device("cpu")
 
@@ -724,13 +726,15 @@ def _prepare_backend(
         elif is_musa_available():
             backend = "mccl"
             distributed_type = DistributedType.MULTI_MUSA
+        # NPU should be checked before CUDA when using `transfer_to_npu`
+        # See issue #3020: https://github.com/huggingface/accelerate/issues/3020
+        elif is_npu_available():
+            backend = "hccl"
+            distributed_type = DistributedType.MULTI_NPU
         elif torch.cuda.is_available():
             if backend is None:
                 backend = "nccl"
             distributed_type = DistributedType.MULTI_GPU
-        elif is_npu_available():
-            backend = "hccl"
-            distributed_type = DistributedType.MULTI_NPU
 
         if distributed_type is None and (
            int(os.environ.get("LOCAL_RANK", -1)) != -1
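
Note: a minimal sketch of the ordering hazard this patch addresses, assuming `torch_npu` is installed and its `transfer_to_npu` shim is imported (the shim monkey-patches the `torch.cuda` namespace to route to the NPU, so the CUDA probe succeeds on an NPU-only machine; see the linked issue #3020). This is illustrative only, not part of the patch:

    import torch
    import torch_npu  # noqa: F401
    from torch_npu.contrib import transfer_to_npu  # noqa: F401  # patches torch.cuda.* to route to NPU

    # On an NPU-only machine, the patched probe now reports CUDA:
    print(torch.cuda.is_available())  # True, even with no CUDA device present

    # A selection chain that tests torch.cuda.is_available() before
    # is_npu_available() would therefore return torch.device("cuda");
    # checking the NPU first, as this patch does, yields torch.device("npu").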