From deb7edb3a4ee92f3c5e0344fcb8657d646df5628 Mon Sep 17 00:00:00 2001
From: "Wang, Yi"
Date: Thu, 30 Oct 2025 13:07:11 +0800
Subject: [PATCH] fix the crash in Wan-AI/Wan2.2-TI2V-5B-Diffusers if CP is enabled

Signed-off-by: Wang, Yi
---
 .../models/transformers/transformer_wan.py | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py
index dd75fb124f1a..38ba7d64c424 100644
--- a/src/diffusers/models/transformers/transformer_wan.py
+++ b/src/diffusers/models/transformers/transformer_wan.py
@@ -20,6 +20,7 @@
 import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
+from ...hooks.context_parallel import EquipartitionSharder
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
@@ -660,6 +661,15 @@ def forward(
             timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=ts_seq_len
         )
         if ts_seq_len is not None:
+            # Check if running under context parallel and split along seq_len dimension
+            if hasattr(self, '_parallel_config') and self._parallel_config is not None:
+                cp_config = getattr(self._parallel_config, 'context_parallel_config', None)
+                if cp_config is not None and cp_config._world_size > 1:
+                    timestep_proj = EquipartitionSharder.shard(
+                        timestep_proj,
+                        dim=1,
+                        mesh=cp_config._flattened_mesh
+                    )
             # batch_size, seq_len, 6, inner_dim
             timestep_proj = timestep_proj.unflatten(2, (6, -1))
         else:
@@ -681,6 +691,15 @@
 
         # 5. Output norm, projection & unpatchify
         if temb.ndim == 3:
+            # Check if running under context parallel and split along seq_len dimension
+            if hasattr(self, '_parallel_config') and self._parallel_config is not None:
+                cp_config = getattr(self._parallel_config, 'context_parallel_config', None)
+                if cp_config is not None and cp_config._world_size > 1:
+                    temb = EquipartitionSharder.shard(
+                        temb,
+                        dim=1,
+                        mesh=cp_config._flattened_mesh
+                    )
             # batch_size, seq_len, inner_dim (wan 2.2 ti2v)
             shift, scale = (self.scale_shift_table.unsqueeze(0).to(temb.device) + temb.unsqueeze(2)).chunk(2, dim=2)
             shift = shift.squeeze(2)
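
Background on the change, with an illustrative sketch: when context parallelism is enabled, the transformer's hidden_states are presumably already split along the sequence dimension across ranks, while in the Wan 2.2 TI2V path (ts_seq_len is not None / temb.ndim == 3) the condition embedder still returns timestep_proj and temb at full sequence length, so the per-token modulation no longer lines up with the local shard. The patch therefore shards both tensors along dim=1 (the seq_len dimension, per the in-code shape comments) with EquipartitionSharder before they are used. The snippet below is a minimal sketch of that sharding behavior in plain torch, not the diffusers API; the helper name shard_along_seq, the shapes, and the even-split assumption are illustrative only.

import torch


def shard_along_seq(tensor: torch.Tensor, rank: int, world_size: int, dim: int = 1) -> torch.Tensor:
    # Keep only this rank's equal-sized chunk along `dim` (the seq_len dimension),
    # mirroring what an equipartition shard is expected to do per context-parallel rank.
    assert tensor.size(dim) % world_size == 0, "seq_len must divide evenly across ranks"
    return tensor.chunk(world_size, dim=dim)[rank]


# Example: a per-token timestep projection of shape (batch, seq_len, 6 * inner_dim).
timestep_proj = torch.randn(2, 8, 6 * 4)
local_proj = shard_along_seq(timestep_proj, rank=0, world_size=2)
print(local_proj.shape)  # torch.Size([2, 4, 24]), i.e. only this rank's sequence slice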