Hi team,
Could we add the attention sink mechanism to flash attention, to enable gpt-oss model support?
I think we should add something like the snippet below right after the QK product (attn_weights = jnp.einsum(...)) here.
Reference implementation here, and the main logic:
import torch
import torch.nn.functional as F

# Broadcast the learned per-head sink logits to (batch, num_heads, q_len, 1)
sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1)
# Append the sink as an extra logit column, then subtract the row max for numerical stability
combined_logits = torch.cat([attn_weights, sinks], dim=-1)
combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values
probs = F.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype)
scores = probs[..., :-1]  # we drop the sink here; it only absorbs probability mass
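For the JAX side, here is a minimal sketch of the equivalent post-einsum step. The function name apply_attention_sinks and the shape conventions are my own illustrative assumptions, not this repo's actual kernel internals; the real change would need to live inside the flash attention implementation itself.

import jax
import jax.numpy as jnp

def apply_attention_sinks(attn_weights: jax.Array, sinks: jax.Array) -> jax.Array:
    """Softmax over the QK logits plus a per-head sink logit, then drop the sink.

    attn_weights: (batch, num_heads, q_len, kv_len) raw QK logits.
    sinks:        (num_heads,) learned sink logits (assumed parameter shape).
    """
    b, h, q, _ = attn_weights.shape
    # Broadcast the sink logits to (batch, num_heads, q_len, 1) and append as an extra column.
    sink_col = jnp.broadcast_to(sinks.reshape(1, h, 1, 1), (b, h, q, 1))
    combined = jnp.concatenate([attn_weights, sink_col], axis=-1)
    # Subtract the row max for numerical stability, matching the reference implementation.
    combined = combined - jnp.max(combined, axis=-1, keepdims=True)
    probs = jax.nn.softmax(combined, axis=-1)
    return probs[..., :-1]  # drop the sink column

# Quick shape check with dummy values:
scores = apply_attention_sinks(jnp.zeros((2, 4, 8, 8)), jnp.zeros((4,)))
assert scores.shape == (2, 4, 8, 8)

Note that the rows of the returned scores sum to less than 1: the sink column absorbs the leftover probability mass, which is exactly what lets a head attend to "nothing".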