add MLU devices for rng state saving and loading. (#2940)
* Add Cambricon MLU accelerator support

* up mlu support for test

* fix mlu device MULTI_MLU

* Update src/accelerate/utils/imports.py

it's beautiful !

Co-authored-by: Zach Mueller <[email protected]>

* up mlu for quality check

* fix mlu device longTensor error

* fix mlu device tensor dtype check

* fix mlu device send_to_device with torch dynamo error

* Refactor AcceleratorState

* Should be near complete now

* Last missing piece

* Make my way to the acceleratorstate

* Include update to global var

* Don't use global

* gpu -> cuda

* Don't use update for dict, easier to read

* Fix tests

* stash

* Getting closer...

* Needed to spawn at the very end after env was setup

* Explain set_device before deepspeed

* Make docstring more accurate

* Early return instead

* Delineate blocks

* Make prepare_backend return state + backend for clarity/less magic

* fix mlu longtensor.to() bugs.

* fix MLU devices rng state save and load.

---------

Co-authored-by: Zach Mueller <[email protected]>
huismiling and muellerzr authored Jul 31, 2024
1 parent 308a8e9 commit 386f7d2
Showing 3 changed files with 10 additions and 1 deletion.
src/accelerate/checkpointing.py (5 additions, 0 deletions)
@@ -32,6 +32,7 @@
     SCHEDULER_NAME,
     WEIGHTS_NAME,
     get_pretty_name,
+    is_mlu_available,
     is_torch_xla_available,
     is_xpu_available,
     save,
@@ -143,6 +144,8 @@ def save_accelerator_state(
     states["torch_manual_seed"] = torch.get_rng_state()
     if is_xpu_available():
         states["torch_xpu_manual_seed"] = torch.xpu.get_rng_state_all()
+    if is_mlu_available():
+        states["torch_mlu_manual_seed"] = torch.mlu.get_rng_state_all()
     else:
         states["torch_cuda_manual_seed"] = torch.cuda.get_rng_state_all()
     if is_torch_xla_available():
@@ -255,6 +258,8 @@ def load_accelerator_state(
         torch.set_rng_state(states["torch_manual_seed"])
         if is_xpu_available():
             torch.xpu.set_rng_state_all(states["torch_xpu_manual_seed"])
+        if is_mlu_available():
+            torch.mlu.set_rng_state_all(states["torch_mlu_manual_seed"])
         else:
            torch.cuda.set_rng_state_all(states["torch_cuda_manual_seed"])
         if is_torch_xla_available():
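For context, not part of the commit: save_accelerator_state and load_accelerator_state back the public Accelerator.save_state and Accelerator.load_state calls, so on a Cambricon MLU machine a checkpoint round trip should now also capture and restore the per-device MLU RNG state. A minimal sketch, assuming torch_mlu is installed (so torch.mlu exists) and "./ckpt" is a writable directory:

from accelerate import Accelerator

accelerator = Accelerator()        # resolves to the MLU device when one is available
accelerator.save_state("./ckpt")   # the saved RNG pickle now also holds torch_mlu_manual_seed
# ... training continues, RNG state drifts ...
accelerator.load_state("./ckpt")   # torch.mlu RNG state is restored on every visible MLU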
src/accelerate/commands/env.py (2 additions, 0 deletions)
@@ -81,6 +81,8 @@ def env_command(args):
     }
     if pt_cuda_available:
         info["GPU type"] = torch.cuda.get_device_name()
+    if pt_mlu_available:
+        info["MLU type"] = torch.mlu.get_device_name()
     if pt_npu_available:
         info["CANN version"] = torch.version.cann
 
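For context, not part of the commit: the pt_mlu_available flag used above is not defined in this hunk; presumably it mirrors pt_cuda_available and comes from is_mlu_available(). A standalone sketch of the same device report, assuming is_mlu_available is importable from accelerate.utils and torch_mlu is installed:

import torch

from accelerate.utils import is_mlu_available

info = {}
if torch.cuda.is_available():
    info["GPU type"] = torch.cuda.get_device_name()
if is_mlu_available():
    info["MLU type"] = torch.mlu.get_device_name()  # same call the diff adds to env_command
print(info)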
src/accelerate/utils/dataclasses.py (3 additions, 1 deletion)
@@ -32,7 +32,7 @@
 
 from .constants import FSDP_AUTO_WRAP_POLICY, FSDP_BACKWARD_PREFETCH, FSDP_SHARDING_STRATEGY, FSDP_STATE_DICT_TYPE
 from .environment import str_to_bool
-from .imports import is_cuda_available, is_npu_available, is_xpu_available
+from .imports import is_cuda_available, is_mlu_available, is_npu_available, is_xpu_available
 from .versions import compare_versions
 
 
@@ -1341,6 +1341,8 @@ def __post_init__(self):
         if self.sync_module_states:
             if is_npu_available():
                 device = torch.npu.current_device()
+            elif is_mlu_available():
+                device = torch.mlu.current_device()
             elif is_cuda_available():
                 device = torch.cuda.current_device()
             elif is_xpu_available():
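For context, not part of the commit: with this change the sync_module_states branch resolves the current device in the order NPU, MLU, CUDA, XPU. A standalone sketch of that priority chain, assuming the is_*_available helpers are importable from accelerate.utils and the matching torch backend extension is installed:

import torch

from accelerate.utils import (
    is_cuda_available,
    is_mlu_available,
    is_npu_available,
    is_xpu_available,
)

def current_accelerator_device():
    # Same priority order the FSDP plugin follows when sync_module_states is set.
    if is_npu_available():
        return torch.npu.current_device()
    if is_mlu_available():
        return torch.mlu.current_device()
    if is_cuda_available():
        return torch.cuda.current_device()
    if is_xpu_available():
        return torch.xpu.current_device()
    raise RuntimeError("sync_module_states requires an NPU, MLU, CUDA, or XPU device.")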
