Context Parallel w/ Ring & Ulysses & Unified Attention #11941

Draft · wants to merge 30 commits into base: main

Changes from all commits (30 commits):
d7b9e42  update (a-r-r-o-w, Jul 14, 2025)
7e97e43  update (a-r-r-o-w, Jul 14, 2025)
ecabd2a  add coauthor (a-r-r-o-w, Jul 14, 2025)
ff21b7f  improve test (a-r-r-o-w, Jul 14, 2025)
b8f7fe6  handle ip adapter params correctly (a-r-r-o-w, Jul 14, 2025)
17b678f  Merge branch 'main' into to-single-file/flux (a-r-r-o-w, Jul 15, 2025)
0cda91d  fix chroma qkv fusion test (a-r-r-o-w, Jul 15, 2025)
bc64f12  fix fastercache implementation (a-r-r-o-w, Jul 15, 2025)
a0b276d  fix more tests (a-r-r-o-w, Jul 15, 2025)
c141520  fight more tests (a-r-r-o-w, Jul 15, 2025)
4dcd672  add back set_attention_backend (a-r-r-o-w, Jul 15, 2025)
576da52  update (a-r-r-o-w, Jul 15, 2025)
e909b73  update (a-r-r-o-w, Jul 15, 2025)
1e7217f  make style (a-r-r-o-w, Jul 15, 2025)
4f52e34  make fix-copies (a-r-r-o-w, Jul 15, 2025)
d9c1683  make ip adapter processor compatible with attention dispatcher (a-r-r-o-w, Jul 15, 2025)
a73cb39  refactor chroma as well (a-r-r-o-w, Jul 15, 2025)
1e6b1c5  remove rmsnorm assert (a-r-r-o-w, Jul 16, 2025)
251bb61  minify and deprecate npu/xla processors (a-r-r-o-w, Jul 16, 2025)
84d2c84  Merge branch 'main' into to-single-file/flux (a-r-r-o-w, Jul 16, 2025)
51fed50  update (a-r-r-o-w, Jul 16, 2025)
9f37b87  Merge branch 'main' into to-single-file/flux (a-r-r-o-w, Jul 16, 2025)
7973626  refactor (a-r-r-o-w, Jul 16, 2025)
f859fdf  refactor; support flash attention 2 with cp (a-r-r-o-w, Jul 16, 2025)
e76fc94  fix (a-r-r-o-w, Jul 16, 2025)
171152f  support sage attention with cp (a-r-r-o-w, Jul 16, 2025)
62f164d  make torch compile compatible (a-r-r-o-w, Jul 16, 2025)
731b3bb  Merge branch 'to-single-file/flux' into attn-dispatcher-cp-and-training (a-r-r-o-w, Jul 16, 2025)
ff8ef45  Merge branch 'main' into attn-dispatcher-cp-and-training (a-r-r-o-w, Jul 17, 2025)
26a5a5c  update (a-r-r-o-w, Jul 17, 2025)
1 change: 1 addition & 0 deletions src/diffusers/hooks/__init__.py
@@ -16,6 +16,7 @@


if is_torch_available():
    from .context_parallel import apply_context_parallel
    from .faster_cache import FasterCacheConfig, apply_faster_cache
    from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache
    from .group_offloading import apply_group_offloading
299 changes: 299 additions & 0 deletions src/diffusers/hooks/context_parallel.py
@@ -0,0 +1,299 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from dataclasses import dataclass
from typing import Dict, List, Type, Union

import torch
import torch.distributed._functional_collectives as funcol

from ..models._modeling_parallel import (
    ContextParallelInput,
    ContextParallelModelPlan,
    ContextParallelOutput,
    ParallelConfig,
)
from ..models.attention_dispatch import _parallel_context
from ..utils import get_logger
from ..utils.torch_utils import unwrap_module
from .hooks import HookRegistry, ModelHook


logger = get_logger(__name__) # pylint: disable=invalid-name

_CONTEXT_PARALLEL_MODEL_HOOK = "context_parallel_model_hook"
_CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE = "cp_input---{}"
_CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE = "cp_output---{}"


# TODO(aryan): consolidate with ._helpers.TransformerBlockMetadata
@dataclass
class ModuleForwardMetadata:
    cached_parameter_indices: Dict[str, int] = None
    _cls: Type = None

    def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None):
        kwargs = kwargs or {}

        if identifier in kwargs:
            return kwargs[identifier], True, None

        if self.cached_parameter_indices is not None:
            index = self.cached_parameter_indices.get(identifier, None)
            if index is None:
                raise ValueError(f"Parameter '{identifier}' not found in cached indices.")
            return args[index], False, index

        if self._cls is None:
            raise ValueError("Model class is not set for metadata.")

        parameters = list(inspect.signature(self._cls.forward).parameters.keys())
        parameters = parameters[1:]  # skip `self`
        self.cached_parameter_indices = {param: i for i, param in enumerate(parameters)}

        if identifier not in self.cached_parameter_indices:
            raise ValueError(f"Parameter '{identifier}' not found in function signature but was requested.")

        index = self.cached_parameter_indices[identifier]

        if index >= len(args):
            raise ValueError(f"Expected at least {index + 1} positional arguments but got {len(args)}.")

        return args[index], False, index
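As a side note, the signature-introspection trick above can be reproduced standalone; the toy module and argument names below are purely illustrative and not part of the PR:

```python
import inspect

import torch


class ToyBlock(torch.nn.Module):
    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
        return hidden_states


# Build the same name -> positional-index map that ModuleForwardMetadata caches.
params = list(inspect.signature(ToyBlock.forward).parameters.keys())[1:]  # skip `self`
indices = {name: i for i, name in enumerate(params)}

args = (torch.randn(1, 8, 16), torch.randn(1, 4, 16))
# `encoder_hidden_states` was passed positionally, so it is recovered via its index.
assert args[indices["encoder_hidden_states"]].shape == (1, 4, 16)
```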


def apply_context_parallel(
    module: torch.nn.Module,
    parallel_config: ParallelConfig,
    plan: Dict[str, ContextParallelModelPlan],
) -> None:
"""Apply context parallel on a model."""
logger.debug(f"Applying context parallel with CP mesh: {parallel_config.cp_mesh} and plan: {plan}")

for module_id, cp_model_plan in plan.items():
submodule = _get_submodule_by_name(module, module_id)
if not isinstance(submodule, list):
submodule = [submodule]

logger.debug(f"Applying ContextParallelHook to {module_id=} identifying a total of {len(submodule)} modules")

for m in submodule:
if isinstance(cp_model_plan, dict):
hook = ContextParallelSplitHook(cp_model_plan, parallel_config)
hook_name = _CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE.format(module_id)
elif isinstance(cp_model_plan, (ContextParallelOutput, list, tuple)):
if isinstance(cp_model_plan, ContextParallelOutput):
cp_model_plan = [cp_model_plan]
if not all(isinstance(x, ContextParallelOutput) for x in cp_model_plan):
raise ValueError(f"Expected all elements of cp_model_plan to be CPOutput, but got {cp_model_plan}")
                hook = ContextParallelGatherHook(cp_model_plan, parallel_config)
                hook_name = _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE.format(module_id)
            else:
                raise ValueError(f"Unsupported context parallel model plan type: {type(cp_model_plan)}")
            registry = HookRegistry.check_if_exists_or_initialize(m)
            registry.register_hook(hook, hook_name)

    # HACK: we cannot use context managers, setattr, or similar solutions inside an overridden
    # diffusers hook forward method because Dynamo fails to trace it. Instead, we make use of the module hooks
    # available in PyTorch to set the parallel context before/after the forward/backward pass.
    # It is dirty, but fullgraph=True tracing works because of this, and I haven't found a better solution yet.
    # The previous/older implementation simply did this:
    #     def new_forward(self, ...):
    #         with _parallel_context(parallel_config):
    #             return self.fn_ref.original_forward(*args, **kwargs)
    # TODO: ask the PyTorch team for help on how to improve this
    @torch.compiler.disable
    def forward_pre_hook(module, args):
        module._diffusers_parallel_config_setter_context = _parallel_context(parallel_config)
        module._diffusers_parallel_config_setter_context.__enter__()

    @torch.compiler.disable
    def forward_hook(module, args, output):
        if module._diffusers_parallel_config_setter_context is not None:
            module._diffusers_parallel_config_setter_context.__exit__(None, None, None)
            module._diffusers_parallel_config_setter_context = None

    @torch.compiler.disable
    def backward_pre_hook(module, grad_output):
        module._diffusers_parallel_config_setter_context = _parallel_context(parallel_config)
        module._diffusers_parallel_config_setter_context.__enter__()

    @torch.compiler.disable
    def backward_hook(module, grad_output, grad_input):
        if module._diffusers_parallel_config_setter_context is not None:
            module._diffusers_parallel_config_setter_context.__exit__(None, None, None)
            module._diffusers_parallel_config_setter_context = None

    module.register_forward_pre_hook(forward_pre_hook)
    module.register_forward_hook(forward_hook)
    module.register_full_backward_pre_hook(backward_pre_hook)
    module.register_full_backward_hook(backward_hook)
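For illustration, the module-hook workaround described in the comment above can be reproduced in isolation with a generic context manager standing in for `_parallel_context`; everything below is a minimal sketch, not part of the PR:

```python
import contextlib

import torch


@contextlib.contextmanager
def dummy_context():  # stand-in for _parallel_context(parallel_config)
    print("enter")
    yield
    print("exit")


model = torch.nn.Linear(4, 4)


@torch.compiler.disable
def pre_hook(module, args):
    # Enter the context before forward runs, without touching the traced forward itself.
    module._ctx = dummy_context()
    module._ctx.__enter__()


@torch.compiler.disable
def post_hook(module, args, output):
    # Exit the context after forward returns.
    module._ctx.__exit__(None, None, None)
    module._ctx = None


model.register_forward_pre_hook(pre_hook)
model.register_forward_hook(post_hook)
model(torch.randn(2, 4))  # prints "enter", runs forward, then prints "exit"
```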


class ContextParallelSplitHook(ModelHook):
    def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ParallelConfig) -> None:
        super().__init__()
        self.metadata = metadata
        self.parallel_config = parallel_config
        self.module_forward_metadata = None

    def initialize_hook(self, module):
        cls = unwrap_module(module).__class__
        self.module_forward_metadata = ModuleForwardMetadata(_cls=cls)
        return module

    def pre_forward(self, module, *args, **kwargs):
        args_list = list(args)

        for name, cpm in self.metadata.items():
            if isinstance(cpm, ContextParallelInput) and cpm.split_output:
                continue

            # Maybe the parameter was passed as a keyword argument
            input_val, is_kwarg, index = self.module_forward_metadata._get_parameter_from_args_kwargs(
                name, args_list, kwargs
            )

            if input_val is None:
                continue

            # The input_val may be a tensor or a list/tuple of tensors. In certain cases, the user may choose to
            # shard the output instead of the input for a particular layer by setting split_output=True
            if isinstance(input_val, torch.Tensor):
                input_val = self._prepare_cp_input(input_val, cpm)
            elif isinstance(input_val, (list, tuple)):
                if len(input_val) != len(cpm):
                    raise ValueError(
                        f"Expected input model plan to have {len(input_val)} elements, but got {len(cpm)}."
                    )
                sharded_input_val = []
                for i, x in enumerate(input_val):
                    if torch.is_tensor(x) and not cpm[i].split_output:
                        x = self._prepare_cp_input(x, cpm[i])
                    sharded_input_val.append(x)
                input_val = sharded_input_val
            else:
                raise ValueError(f"Unsupported input type: {type(input_val)}")

            if is_kwarg:
                kwargs[name] = input_val
            elif index is not None and index < len(args_list):
                args_list[index] = input_val
            else:
                raise ValueError(
                    f"An unexpected error occurred while processing the input '{name}'. Please open an "
                    f"issue at https://github.com/huggingface/diffusers/issues and provide a minimal reproducible "
                    f"example along with the full stack trace."
                )

        return tuple(args_list), kwargs

    def post_forward(self, module, output):
        is_tensor = isinstance(output, torch.Tensor)
        is_tensor_list = isinstance(output, (list, tuple)) and all(isinstance(x, torch.Tensor) for x in output)

        if not is_tensor and not is_tensor_list:
            raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.")

        output = [output] if is_tensor else list(output)
        for index, cpm in self.metadata.items():
            if not isinstance(cpm, ContextParallelInput) or not cpm.split_output:
                continue
            if index >= len(output):
                raise ValueError(f"Index {index} out of bounds for output of length {len(output)}.")
            current_output = output[index]
            current_output = self._prepare_cp_input(current_output, cpm)
            output[index] = current_output

        return output[0] if is_tensor else tuple(output)

    def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) -> torch.Tensor:
        if cp_input.expected_dims is not None and x.dim() != cp_input.expected_dims:
            raise ValueError(
                f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions."
            )
        return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
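To make the plan format concrete: a hypothetical `ContextParallelModelPlan` for a toy transformer could look like the sketch below. The field names (`split_dim`, `expected_dims`, `split_output`, `gather_dim`) are the ones read by the hooks above, but the module ids, dimensions, and the assumption that both classes accept these constructor keywords are illustrative only:

```python
from diffusers.models._modeling_parallel import ContextParallelInput, ContextParallelOutput

# Hypothetical plan: shard the sequence dimension of the root module's inputs,
# and gather it back on the output of a final projection layer.
_cp_plan = {
    "": {
        "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
        "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
    },
    "proj_out": ContextParallelOutput(gather_dim=1),
}

# apply_context_parallel(model, parallel_config, _cp_plan) would then register a
# ContextParallelSplitHook on the root module and a ContextParallelGatherHook on `proj_out`.
```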


class ContextParallelGatherHook(ModelHook):
    def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ParallelConfig) -> None:
        super().__init__()
        self.metadata = metadata
        self.parallel_config = parallel_config

    def post_forward(self, module, output):
        is_tensor = isinstance(output, torch.Tensor)

        if is_tensor:
            output = [output]
        elif not (isinstance(output, (list, tuple)) and all(isinstance(x, torch.Tensor) for x in output)):
            raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.")

        output = list(output)

        if len(output) != len(self.metadata):
            raise ValueError(f"Expected output to have {len(self.metadata)} elements, but got {len(output)}.")

        for i, cpm in enumerate(self.metadata):
            if cpm is None:
                continue
            output[i] = EquipartitionSharder.unshard(output[i], cpm.gather_dim, self.parallel_config._flattened_mesh)

        return output[0] if is_tensor else tuple(output)


class EquipartitionSharder:
    @classmethod
    def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
        assert tensor.size()[dim] % mesh.size() == 0

        # The following is not fullgraph compatible with Dynamo (fails in DeviceMesh.get_rank)
        # return tensor.chunk(mesh.size(), dim=dim)[mesh.get_rank()]

        return tensor.chunk(mesh.size(), dim=dim)[torch.distributed.get_rank(mesh.get_group())]

    @classmethod
    def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
        tensor = tensor.contiguous()
        tensor = funcol.all_gather_tensor(tensor, dim, group=mesh.get_group())
        return tensor
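A single-process illustration of the equipartition arithmetic (the distributed `all_gather_tensor` call is replaced here by a plain concatenation to show the round trip):

```python
import torch

world_size = 2
x = torch.arange(16).reshape(2, 8)

# shard: each rank keeps one contiguous chunk along the split dim (8 % 2 == 0).
shards = x.chunk(world_size, dim=1)
assert shards[0].shape == (2, 4)

# unshard: concatenating the per-rank chunks along the same dim restores the tensor,
# which is what funcol.all_gather_tensor does across the mesh's process group.
assert torch.equal(torch.cat(shards, dim=1), x)
```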


def _get_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]:
    if name.count("*") > 1:
        raise ValueError("Wildcard '*' can only be used once in the name")
    return _find_submodule_by_name(model, name)


def _find_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]:
if name == "":
return model
first_atom, remaining_name = name.split(".", 1) if "." in name else (name, "")
if first_atom == "*":
if not isinstance(model, torch.nn.ModuleList):
raise ValueError("Wildcard '*' can only be used with ModuleList")
submodules = []
for submodule in model:
subsubmodules = _find_submodule_by_name(submodule, remaining_name)
if not isinstance(subsubmodules, list):
subsubmodules = [subsubmodules]
submodules.extend(subsubmodules)
return submodules
else:
if hasattr(model, first_atom):
submodule = getattr(model, first_atom)
return _find_submodule_by_name(submodule, remaining_name)
else:
raise ValueError(f"'{first_atom}' is not a submodule of '{model.__class__.__name__}'")