Causal mask in Mosaic GPU #27765
Replies: 2 comments 1 reply
Hi @apaszke,
Sorry I missed this entirely!
Hi team,
I'm exploring the Mosaic GPU DSL and it looks really promising. Great work! I'm currently going through the attention_mgpu example and wanted to ask: is there a good way to add a causal mask there?
For example, I'd like to reproduce the behaviour of causal_mask from the attention example in Pallas:
I tried this approach, but encountered an unimplemented error due to the use of iota (invoked via jnp.arange):
Is there an alternative or recommended way to implement causal masking in Mosaic GPU?
Thanks in advance!
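For reference, the masking behaviour asked about above can be sketched in plain JAX. This is only an illustration of the semantics, not the Pallas or Mosaic GPU source: the function name `causal_mask_block`, the block-index arguments, and the `MASK_VALUE` constant are all assumptions made for this example. It builds the row and column positions with `jnp.arange`, which is the iota-based step mentioned in the question.

```python
import jax.numpy as jnp

# Assumed mask value: a large negative finite number, so masked logits
# vanish after softmax without producing NaNs. Hypothetical constant.
MASK_VALUE = -1e30


def causal_mask_block(qk, q_block_idx, kv_block_idx):
    """Apply a causal mask to one (block_q, block_kv) tile of logits.

    qk           : attention logits for this tile.
    q_block_idx  : index of the query block along the sequence.
    kv_block_idx : index of the key/value block along the sequence.
    """
    block_q, block_kv = qk.shape
    # Global sequence positions covered by this tile (iota under the hood).
    q_pos = q_block_idx * block_q + jnp.arange(block_q)
    kv_pos = kv_block_idx * block_kv + jnp.arange(block_kv)
    # A query may attend only to keys at the same or an earlier position.
    keep = q_pos[:, None] >= kv_pos[None, :]
    return jnp.where(keep, qk, MASK_VALUE)


# Usage: the diagonal tile keeps its lower triangle only.
tile = jnp.zeros((2, 2), dtype=jnp.float32)
masked = causal_mask_block(tile, q_block_idx=0, kv_block_idx=0)
```

Tiles strictly below the diagonal come back unchanged and tiles strictly above it are fully masked, so only diagonal tiles need the element-wise comparison. Whether the equivalent position computation lowers in Mosaic GPU is exactly the open question here; this sketch does not answer it.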