Add attention sink to flash attention #2070

@RissyRan

Description

Hi team,

Could we add an attention sink mechanism to flash attention, to enable gpt-oss model support?

I think we should add something like the snippet below right after this qk_product (attn_weights = jnp.einsum(...)) here.

Reference implementation is here, and the main logic is:

sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1)  # broadcast per-head sink logits to (batch, num_heads, q_len, 1)
combined_logits = torch.cat([attn_weights, sinks], dim=-1)  # append the sink as one extra column on the key axis
combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values  # subtract the row max for a numerically stable softmax
probs = F.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype)
scores = probs[..., :-1]  # we drop the sink here
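
For reference, a minimal JAX sketch of the same math (not the actual MaxText code; softmax_with_sink, attn_weights, and sinks are placeholder names, and the (batch, num_heads, q_len, kv_len) layout is an assumption):

import jax
import jax.numpy as jnp

def softmax_with_sink(attn_weights, sinks):
    # attn_weights: (batch, num_heads, q_len, kv_len) raw qk logits (assumed layout).
    # sinks: (num_heads,) learned per-head sink logits (hypothetical parameter name).
    b, h, q, _ = attn_weights.shape
    # Broadcast each head's sink logit into one extra "key" column.
    sink_col = jnp.broadcast_to(sinks.reshape(1, h, 1, 1), (b, h, q, 1))
    combined = jnp.concatenate([attn_weights, sink_col], axis=-1)
    # Numerically stable softmax over the keys plus the sink column.
    combined = combined - jnp.max(combined, axis=-1, keepdims=True)
    probs = jax.nn.softmax(combined, axis=-1)
    # Drop the sink column; each row now sums to less than 1, which is the
    # intended effect of the attention sink.
    return probs[..., :-1]

# Shape check with dummy inputs:
# softmax_with_sink(jnp.zeros((2, 4, 8, 16)), jnp.zeros((4,))).shape == (2, 4, 8, 16)

Note that inside the flash attention kernel itself, the sink logit would presumably need to be folded into the running max and softmax denominator of the block-wise (online) softmax rather than materialized as an extra column, but the sketch above shows the intended math.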
