From 386f7d2825409216e5ce4af0e3060d07f7d44914 Mon Sep 17 00:00:00 2001
From: huismiling
Date: Thu, 1 Aug 2024 04:33:15 +0800
Subject: [PATCH] add MLU devices for rng state saving and loading. (#2940)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Add Cambricon MLU accelerator support

* up mlu support for test

* fix mlu device MULTI_MLU

* Update src/accelerate/utils/imports.py

it's beautiful !

Co-authored-by: Zach Mueller

* up mlu for quality check

* fix mlu device longTensor error

* fix mlu device tensor dtype check

* fix mlu device send_to_device with torch dynamo error

* Refactor AcceleratorState

* Should be near complete now

* Last missing piece

* Make my way to the acceleratorstate

* Include update to global var

* Don't use global

* gpu -> cuda

* Don't use update for dict, easier to read

* Fix tests

* stash

* Getting closer...

* Needed to spawn at the very end after env was setup

* Explain set_device before deepspeed

* Make docstring more accurate

* Early return insteaD

* Delineat blocks

* Make prepare_backend return state + backend for clarity/less magic

* fix mlu longtensor.to() bugs.

* fix MLU devices rng state save and load.

---------

Co-authored-by: Zach Mueller
---
 src/accelerate/checkpointing.py     | 5 +++++
 src/accelerate/commands/env.py      | 2 ++
 src/accelerate/utils/dataclasses.py | 4 +++-
 3 files changed, 10 insertions(+), 1 deletion(-)

diff --git a/src/accelerate/checkpointing.py b/src/accelerate/checkpointing.py
index 185ba0e04c4..aebe6e1c77a 100644
--- a/src/accelerate/checkpointing.py
+++ b/src/accelerate/checkpointing.py
@@ -32,6 +32,7 @@
     SCHEDULER_NAME,
     WEIGHTS_NAME,
     get_pretty_name,
+    is_mlu_available,
     is_torch_xla_available,
     is_xpu_available,
     save,
@@ -143,6 +144,8 @@ def save_accelerator_state(
     states["torch_manual_seed"] = torch.get_rng_state()
     if is_xpu_available():
         states["torch_xpu_manual_seed"] = torch.xpu.get_rng_state_all()
+    if is_mlu_available():
+        states["torch_mlu_manual_seed"] = torch.mlu.get_rng_state_all()
     else:
         states["torch_cuda_manual_seed"] = torch.cuda.get_rng_state_all()
     if is_torch_xla_available():
@@ -255,6 +258,8 @@ def load_accelerator_state(
         torch.set_rng_state(states["torch_manual_seed"])
         if is_xpu_available():
             torch.xpu.set_rng_state_all(states["torch_xpu_manual_seed"])
+        if is_mlu_available():
+            torch.mlu.set_rng_state_all(states["torch_mlu_manual_seed"])
         else:
             torch.cuda.set_rng_state_all(states["torch_cuda_manual_seed"])
         if is_torch_xla_available():
diff --git a/src/accelerate/commands/env.py b/src/accelerate/commands/env.py
index 7078c6c0adc..7dd5995f6b4 100644
--- a/src/accelerate/commands/env.py
+++ b/src/accelerate/commands/env.py
@@ -81,6 +81,8 @@ def env_command(args):
     }
     if pt_cuda_available:
         info["GPU type"] = torch.cuda.get_device_name()
+    if pt_mlu_available:
+        info["MLU type"] = torch.mlu.get_device_name()
     if pt_npu_available:
         info["CANN version"] = torch.version.cann
 
diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py
index b12fde45bcf..cf41bc76b62 100644
--- a/src/accelerate/utils/dataclasses.py
+++ b/src/accelerate/utils/dataclasses.py
@@ -32,7 +32,7 @@
 
 from .constants import FSDP_AUTO_WRAP_POLICY, FSDP_BACKWARD_PREFETCH, FSDP_SHARDING_STRATEGY, FSDP_STATE_DICT_TYPE
 from .environment import str_to_bool
-from .imports import is_cuda_available, is_npu_available, is_xpu_available
+from .imports import is_cuda_available, is_mlu_available, is_npu_available, is_xpu_available
 from .versions import compare_versions
 
 
@@ -1341,6 +1341,8 @@ def __post_init__(self):
         if self.sync_module_states:
             if is_npu_available():
                 device = torch.npu.current_device()
+            elif is_mlu_available():
+                device = torch.mlu.current_device()
             elif is_cuda_available():
                 device = torch.cuda.current_device()
             elif is_xpu_available():
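
For reference, below is a minimal standalone sketch (not part of the patch itself) of the per-backend RNG-state save/restore pattern that the hunks above extend to MLU. It assumes the public accelerate.utils helpers is_mlu_available and is_xpu_available and the torch.mlu namespace registered by Cambricon's torch_mlu extension; absent backends are simply skipped, which simplifies the library's actual branching.

# Illustrative sketch only -- not part of the patch above.
# Collects and restores per-backend RNG states, mirroring the pattern the
# patch extends to MLU. torch.mlu is assumed to be provided by Cambricon's
# torch_mlu extension; backends that are not present are skipped.
import torch

from accelerate.utils import is_mlu_available, is_xpu_available


def collect_rng_states() -> dict:
    # CPU generator state is always saved; accelerator states only if present.
    states = {"torch_manual_seed": torch.get_rng_state()}
    if is_xpu_available():
        states["torch_xpu_manual_seed"] = torch.xpu.get_rng_state_all()
    if is_mlu_available():
        states["torch_mlu_manual_seed"] = torch.mlu.get_rng_state_all()
    elif torch.cuda.is_available():
        states["torch_cuda_manual_seed"] = torch.cuda.get_rng_state_all()
    return states


def restore_rng_states(states: dict) -> None:
    # Restore only the entries that were collected on the saving machine.
    torch.set_rng_state(states["torch_manual_seed"])
    if "torch_xpu_manual_seed" in states:
        torch.xpu.set_rng_state_all(states["torch_xpu_manual_seed"])
    if "torch_mlu_manual_seed" in states:
        torch.mlu.set_rng_state_all(states["torch_mlu_manual_seed"])
    if "torch_cuda_manual_seed" in states:
        torch.cuda.set_rng_state_all(states["torch_cuda_manual_seed"])

On a CUDA-only machine, collect_rng_states() returns the CPU and CUDA entries and restore_rng_states() round-trips them; the dict can be persisted with torch.save/torch.load, analogous to the RNG_STATE_NAME .pkl file that save_accelerator_state writes.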