19 changes: 19 additions & 0 deletions src/diffusers/models/transformers/transformer_wan.py
@@ -20,6 +20,7 @@
import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...hooks.context_parallel import EquipartitionSharder
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
@@ -660,6 +661,15 @@ def forward(
    timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=ts_seq_len
)
if ts_seq_len is not None:
    # Check if running under context parallel and split along seq_len dimension
    if hasattr(self, '_parallel_config') and self._parallel_config is not None:
Member:
Could you elaborate why this is needed?

Contributor Author:
When CP is enabled, seq_len is split. The timestep shape is [batch_size, seq_len, 6, inner_dim], so it should be split along dim 1 as well, since the hidden states are split along the seq_len dim too; otherwise a shape mismatch will occur.
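To make the mismatch concrete, a minimal standalone sketch (shapes and world size are illustrative, not taken from this PR):

```python
import torch

world_size = 2  # assumed CP world size
batch, seq_len, dim = 1, 8, 4

# hidden_states is already sharded on the seq_len dim by the CP hooks;
# timestep_proj is not, so dim 1 disagrees (4 vs 8).
hidden_states = torch.randn(batch, seq_len // world_size, dim)
timestep_proj = torch.randn(batch, seq_len, 6, dim)

try:
    _ = hidden_states + timestep_proj[:, :, 0]
except RuntimeError as e:
    print("shape mismatch:", e)

# Keeping only this rank's slice of timestep_proj (roughly what
# EquipartitionSharder.shard does across the mesh) restores matching shapes.
rank = 0
local_proj = timestep_proj.chunk(world_size, dim=1)[rank]
_ = hidden_states + local_proj[:, :, 0]  # [1, 4, 4] + [1, 4, 4]
```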

Contributor Author (@sywangyi, Nov 4, 2025):
You mean split timestep in forward? Adding

    "": {
        "timestep": ContextParallelInput(split_dim=1, split_output=False)
    },

to _cp_plan makes 5B work, but 14B fails, since the 5B timestep has 2 dims while the 14B timestep has 1 dim.

Member:
Hmm, this is an interesting situation. To tackle this, I think we might have to revisit the ContextParallelInput and ContextParallelOutput definitions a bit more.

If we had a way to tell the partitioner that an input might have "dynamic" dimensions depending on the model config (like in this case), and what it should do in that case, that might be a more flexible solution.

@DN6 curious to know what you think.
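One possible shape such an extension could take; this is purely a hypothetical sketch with invented names, not the existing ContextParallelInput API:

```python
from dataclasses import dataclass
from typing import Callable, Optional

import torch


@dataclass
class DynamicContextParallelInput:
    """Hypothetical variant that resolves the split dim from the runtime tensor."""

    resolve_split_dim: Callable[[torch.Tensor], Optional[int]]  # None -> do not split
    split_output: bool = False


# For Wan, a resolver could shard per-token timesteps ([batch, seq_len]) on the
# sequence dim and leave per-sample timesteps ([batch]) alone.
wan_timestep_rule = DynamicContextParallelInput(
    resolve_split_dim=lambda t: 1 if t.ndim == 2 else None,
)
```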

Member:
That makes a lot of things easier, for sure!

Contributor Author:
expand_timesteps is not passed to the WanTransformer3DModel init, so there is currently no way to tell whether the timestep has 2 dims or 1.

Collaborator:
We can add a config if we want to. I want to hear @DN6's thoughts on this first, though.

Member:
Potentially, we could add cp_plan directly as a config and allow the model owner to override it (in this case, we would send a PR to the Wan repo; I think that would be okay).

It's also very much in line with how transformers does it, btw.

Member:
@yiyixuxu @DN6 we allow passing parallel_config through from_pretrained(), too. I wonder if it could make sense to allow users to pass a custom _cp_plan through it as well.
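A rough sketch of how that could look from the user side. Only `parallel_config` through `from_pretrained()` is described above as supported; the `cp_plan` kwarg, the import paths, and the checkpoint id are assumptions for illustration:

```python
from diffusers import ContextParallelConfig, WanTransformer3DModel
from diffusers.models._modeling_parallel import ContextParallelInput  # assumed import path

# Hypothetical: override only the entry under discussion so per-token timesteps
# are sharded on the seq_len dim, leaving the rest of the default plan intact.
custom_cp_plan = {
    "": {"timestep": ContextParallelInput(split_dim=1, split_output=False)},
}

transformer = WanTransformer3DModel.from_pretrained(
    "Wan-AI/Wan2.2-TI2V-5B-Diffusers",                      # illustrative checkpoint
    subfolder="transformer",
    parallel_config=ContextParallelConfig(ring_degree=2),   # assumed config signature
    cp_plan=custom_cp_plan,                                  # hypothetical kwarg proposed above
)
```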

        cp_config = getattr(self._parallel_config, 'context_parallel_config', None)
        if cp_config is not None and cp_config._world_size > 1:
            timestep_proj = EquipartitionSharder.shard(
                timestep_proj,
                dim=1,
                mesh=cp_config._flattened_mesh
            )
    # batch_size, seq_len, 6, inner_dim
    timestep_proj = timestep_proj.unflatten(2, (6, -1))
else:
@@ -681,6 +691,15 @@ def forward(

# 5. Output norm, projection & unpatchify
if temb.ndim == 3:
    # Check if running under context parallel and split along seq_len dimension
    if hasattr(self, '_parallel_config') and self._parallel_config is not None:
        cp_config = getattr(self._parallel_config, 'context_parallel_config', None)
        if cp_config is not None and cp_config._world_size > 1:
            temb = EquipartitionSharder.shard(
                temb,
                dim=1,
                mesh=cp_config._flattened_mesh
            )
    # batch_size, seq_len, inner_dim (wan 2.2 ti2v)
    shift, scale = (self.scale_shift_table.unsqueeze(0).to(temb.device) + temb.unsqueeze(2)).chunk(2, dim=2)
    shift = shift.squeeze(2)