Skip to content

Add generic support for Intel Gaudi accelerator (hpu device) #11328

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions src/diffusers/pipelines/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
_is_valid_type,
is_accelerate_available,
is_accelerate_version,
is_hpu_available,
is_torch_npu_available,
is_torch_version,
is_transformers_version,
Expand Down Expand Up @@ -445,6 +446,20 @@ def module_is_offloaded(module):
f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading."
)

# Enable generic support for Intel Gaudi accelerator using GPU/HPU migration
if device_type == "hpu" and kwargs.pop("hpu_migration", True) and is_hpu_available():
os.environ["PT_HPU_GPU_MIGRATION"] = "1"
logger.debug("Environment variable set: PT_HPU_GPU_MIGRATION=1")

import habana_frameworks.torch # noqa: F401

# HPU hardware check
if not (hasattr(torch, "hpu") and torch.hpu.is_available()):
raise ValueError("You are trying to call `.to('hpu')` but HPU device is unavailable.")

os.environ["PT_HPU_MAX_COMPOUND_OP_SIZE"] = "1"
logger.debug("Environment variable set: PT_HPU_MAX_COMPOUND_OP_SIZE=1")

module_names, _ = self._get_signature_keys(self)
modules = [getattr(self, n, None) for n in module_names]
modules = [m for m in modules if isinstance(m, torch.nn.Module)]
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
is_gguf_version,
is_google_colab,
is_hf_hub_version,
is_hpu_available,
is_inflect_available,
is_invisible_watermark_available,
is_k_diffusion_available,
Expand Down
4 changes: 4 additions & 0 deletions src/diffusers/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,10 @@ def is_timm_available():
return _timm_available


def is_hpu_available():
return all(importlib.util.find_spec(lib) for lib in ("habana_frameworks", "habana_frameworks.torch"))


# docstyle-ignore
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
Expand Down