Refactor gradient checkpointing #10611

Open · wants to merge 4 commits into base: main
Changes from 2 commits
62 changes: 58 additions & 4 deletions src/diffusers/models/modeling_utils.py
@@ -27,6 +27,7 @@

import safetensors
import torch
import torch.utils.checkpoint
from huggingface_hub import DDUFEntry, create_repo, split_torch_state_dict_into_shards
from huggingface_hub.utils import validate_hf_hub_args
from torch import Tensor, nn
@@ -154,6 +155,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
def __init__(self):
super().__init__()

self._gradient_checkpointing_func = None

def __getattr__(self, name: str) -> Any:
"""The only reason we overwrite `getattr` here is to gracefully deprecate accessing
config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 We need to overwrite
@@ -179,22 +182,55 @@ def is_gradient_checkpointing(self) -> bool:
"""
return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())

def enable_gradient_checkpointing(self) -> None:
def enable_gradient_checkpointing(
self,
gradient_checkpointing_func: Optional[Callable] = None,
gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
"""
Activates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
*checkpoint activations* in other frameworks).
"""
if not self._supports_gradient_checkpointing:
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
self.apply(partial(self._set_gradient_checkpointing, value=True))
raise ValueError(
f"{self.__class__.__name__} does not support gradient checkpointing. Please make sure to set the boolean attribute "
f"`_supports_gradient_checkpointing` to `True` in the class definition."
)

user_provided_gradient_checkpointing_func = gradient_checkpointing_func is not None
if gradient_checkpointing_func is None:

def _gradient_checkpointing_func(module, *args):
ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
return torch.utils.checkpoint.checkpoint(
module.__call__,
*args,
**ckpt_kwargs,
)

gradient_checkpointing_func = _gradient_checkpointing_func

if gradient_checkpointing_kwargs is None:
gradient_checkpointing_kwargs = {}

if (
not user_provided_gradient_checkpointing_func
and is_torch_version(">=", "1.11.0")
and inspect.signature(gradient_checkpointing_func).parameters.get("use_reentrant") is not None
):
gradient_checkpointing_kwargs["use_reentrant"] = False

gradient_checkpointing_func = partial(gradient_checkpointing_func, **gradient_checkpointing_kwargs)

self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)

def disable_gradient_checkpointing(self) -> None:
"""
Deactivates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
*checkpoint activations* in other frameworks).
"""
if self._supports_gradient_checkpointing:
self.apply(partial(self._set_gradient_checkpointing, value=False))
self._set_gradient_checkpointing(enable=False)

def set_use_npu_flash_attention(self, valid: bool) -> None:
r"""
@@ -1354,6 +1390,24 @@ def get_memory_footprint(self, return_buffers=True):
mem = mem + mem_bufs
return mem

def _set_gradient_checkpointing(
self, enable: bool = True, gradient_checkpointing_func: Callable = torch.utils.checkpoint.checkpoint
) -> None:
is_gradient_checkpointing_set = False

for name, module in self.named_modules():
if hasattr(module, "gradient_checkpointing"):
logger.debug(f"Setting `gradient_checkpointing={enable}` for '{name}'")
module._gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = enable
is_gradient_checkpointing_set = True

if not is_gradient_checkpointing_set:
raise ValueError(
f"The module {self.__class__.__name__} does not support gradient checkpointing. Please make sure to use a module that supports gradient checkpointing "
f"by creating a boolean attribute `gradient_checkpointing` in the module and setting it to `True`."
)

def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None:
deprecated_attention_block_paths = []

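To make the new contract concrete, here is a minimal, self-contained sketch of a toy `ModelMixin` subclass wired up the way this diff expects. `ToyTransformer` and its layers are illustrative assumptions, not part of the PR; the pieces that do come from the diff are `_supports_gradient_checkpointing`, the `gradient_checkpointing` attribute, `self._gradient_checkpointing_func`, and the `enable_gradient_checkpointing()` / `disable_gradient_checkpointing()` entry points.

```python
# Sketch only: a toy model exercising the refactored ModelMixin machinery shown above.
import torch
import torch.nn as nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin


class ToyTransformer(ModelMixin, ConfigMixin):
    # Opt in; enable_gradient_checkpointing() raises otherwise (see the diff above).
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(self, dim: int = 32, num_layers: int = 2):
        super().__init__()
        self.blocks = nn.ModuleList(
            [nn.TransformerEncoderLayer(d_model=dim, nhead=4, batch_first=True) for _ in range(num_layers)]
        )
        # ModelMixin._set_gradient_checkpointing looks for this boolean attribute.
        self.gradient_checkpointing = False

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                # Installed by enable_gradient_checkpointing(); replaces the old
                # per-model create_custom_forward / ckpt_kwargs boilerplate.
                hidden_states = self._gradient_checkpointing_func(block, hidden_states)
            else:
                hidden_states = block(hidden_states)
        return hidden_states


model = ToyTransformer()
model.train()
model.enable_gradient_checkpointing()  # default: non-reentrant torch.utils.checkpoint.checkpoint

x = torch.randn(2, 8, 32)
loss = model(x).mean()
loss.backward()  # block activations are recomputed here instead of being stored

model.disable_gradient_checkpointing()
```

With the wrapping done once in `ModelMixin`, the per-model `_set_gradient_checkpointing(module, value)` overrides become unnecessary, which is exactly what the `transformer_ltx.py` changes below remove.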
22 changes: 3 additions & 19 deletions src/diffusers/models/transformers/transformer_ltx.py
@@ -22,7 +22,7 @@

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import Attention
@@ -360,10 +360,6 @@ def __init__(

self.gradient_checkpointing = False

def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value

def forward(
self,
hidden_states: torch.Tensor,
@@ -416,25 +412,13 @@ def forward(

for block in self.transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:

def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)

return custom_forward

ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
encoder_attention_mask,
**ckpt_kwargs,
)
else:
hidden_states = block(
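For the call site above, a small sketch of the call convention the refactor standardizes on, independent of any particular model. `default_style_func` and `custom_func` are hypothetical names; what the diff establishes is that the installed function is invoked as `func(module, *inputs)`, and that any user-supplied `gradient_checkpointing_kwargs` are bound onto the function with `functools.partial` before `_set_gradient_checkpointing` distributes it to submodules.

```python
# Sketch only: the call convention behind self._gradient_checkpointing_func(block, ...).
from functools import partial

import torch
import torch.nn as nn
import torch.utils.checkpoint

block = nn.Linear(16, 16)
hidden_states = torch.randn(2, 16, requires_grad=True)


def default_style_func(module, *args):
    # Mirrors the default installed by enable_gradient_checkpointing() on torch >= 1.11.
    return torch.utils.checkpoint.checkpoint(module.__call__, *args, use_reentrant=False)


def custom_func(module, *args, use_reentrant=False):
    # A user-provided function; gradient_checkpointing_kwargs end up as keyword arguments here.
    return torch.utils.checkpoint.checkpoint(module.__call__, *args, use_reentrant=use_reentrant)


# enable_gradient_checkpointing() does this partial() before handing the function to submodules.
installed = partial(custom_func, use_reentrant=False)

# Both calls follow the pattern used in the new transformer_ltx forward:
# hidden_states = self._gradient_checkpointing_func(block, hidden_states, ...)
out = default_style_func(block, hidden_states)
out = installed(block, out)
out.sum().backward()  # checkpointed segments are recomputed during backward
```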