Skip to content
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 59 additions & 12 deletions src/diffusers/models/attention_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1538,18 +1538,65 @@ def _native_attention(
) -> torch.Tensor:
if return_lse:
raise ValueError("Native attention backend does not support setting `return_lse=True`.")
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
out = torch.nn.functional.scaled_dot_product_attention(
query=query,
key=key,
value=value,
attn_mask=attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
scale=scale,
enable_gqa=enable_gqa,
)
out = out.permute(0, 2, 1, 3)
if _parallel_config is None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, 'supports_context_parallel=True' should be also added to register @sywangyi @sayakpaul

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks, added

Copy link
Member

@sayakpaul sayakpaul Oct 31, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@DefTruth that should also fix #12446 (comment) right? Could you give this a check?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@DefTruth that should also fix #12446 (comment) right? Could you give this a check?

confirm fixed

query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
out = torch.nn.functional.scaled_dot_product_attention(
query=query,
key=key,
value=value,
attn_mask=attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
scale=scale,
enable_gqa=enable_gqa,
)
out = out.permute(0, 2, 1, 3)
elif _parallel_config.context_parallel_config.ring_degree == 1:
ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh
world_size = _parallel_config.context_parallel_config.ulysses_degree
group = ulysses_mesh.get_group()

batch_size, seq_len_q_local, num_heads, head_dim = query.shape
_, seq_len_kv_local, _, _ = key.shape
num_heads_local = num_heads // world_size
query = (
query.reshape(batch_size, seq_len_q_local, world_size, num_heads_local, head_dim)
.permute(2, 1, 0, 3, 4)
.contiguous()
)
key = (
key.reshape(batch_size, seq_len_kv_local, world_size, num_heads_local, head_dim)
.permute(2, 1, 0, 3, 4)
.contiguous()
)
value = (
value.reshape(batch_size, seq_len_kv_local, world_size, num_heads_local, head_dim)
.permute(2, 1, 0, 3, 4)
.contiguous()
)
query, key, value = (_all_to_all_single(x, group) for x in (query, key, value))
query, key, value = (x.flatten(0, 1).permute(1, 2, 0, 3).contiguous() for x in (query, key, value))
out = torch.nn.functional.scaled_dot_product_attention(
query=query,
key=key,
value=value,
attn_mask=attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
scale=scale,
enable_gqa=enable_gqa,
)
out = (
out.reshape(batch_size, num_heads_local, world_size, seq_len_q_local, head_dim)
.permute(2, 1, 0, 3, 4)
.contiguous()
)
out = _all_to_all_single(out, group)
out = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous()
else:
raise ValueError(
"Native attention backend does not support context parallelism with `ring_degree` > 1, try Ulysses Attention instead by specifying `ulysses_degree` > 1."
)
return out


Expand Down