
Commit 6908d14

kimishpatel authored and kirklandsign committed
[Executorch][llama] Allow custom sdpa op replacement pass to leverage attention mask
Pull Request resolved: #10285

Previously we assumed that the custom sdpa op always does causal attention. This diff adds an option to this module swap pass so that the custom sdpa op can leverage an attention mask instead of relying on causal attention.

ghstack-source-id: 279292324
@exported-using-ghexport

Differential Revision: [D73222736](https://our.internmc.facebook.com/intern/diff/D73222736/)
1 parent 1f6d3c5 · commit 6908d14
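A minimal usage sketch of the new option (not part of this commit): the import path is inferred from the file touched below, and `build_llama_model()` is a hypothetical placeholder for however the llama module is built in the export flow.

```python
# Sketch only: the import path is inferred from the file location in this
# diff, and build_llama_model() is a hypothetical stand-in for the usual
# llama model setup in the export flow.
from executorch.examples.models.llama.source_transformation.sdpa import (
    replace_sdpa_with_custom_op,
)

model = build_llama_model()  # hypothetical helper

# Default (use_attention_mask=False) keeps the previous behavior: the custom
# sdpa op is called with is_causal=True and no attention mask. With the flag
# set, the swapped-in SDPACustom instead slices the model's attention mask for
# the current positions and calls the op with is_causal=False.
model = replace_sdpa_with_custom_op(model, use_attention_mask=True)
```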

File tree: 2 files changed, +52 −16 lines

examples/models/llama/source_transformation/sdpa.py (+50 −14)

@@ -22,9 +22,15 @@ class SDPACustom(torch.nn.Module):
     def __init__(
         self,
         dim: int,
+        max_context_len,
+        enable_dynamic_shape,
+        use_attention_mask: bool = False,
     ):
         super().__init__()
         self.dim = dim
+        self.max_context_len = max_context_len
+        self.use_attention_mask = use_attention_mask
+        self.enable_dynamic_shape = enable_dynamic_shape
 
     def forward(
         self,
@@ -36,6 +42,16 @@ def forward(
         seqlen,
         mask,
     ):
+        if self.use_attention_mask:
+            if self.enable_dynamic_shape:
+                start_pos = input_pos[-1].item()
+                torch._check_is_size(start_pos)
+                torch._check(start_pos < self.max_context_len)
+                seq_length = q.size(2)
+                mask = mask.narrow(0, start_pos, seq_length)
+            else:
+                mask = mask[input_pos]
+
         q = q.transpose(1, 2)  # (bs, seqlen, n_local_heads, head_dim)
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)
@@ -47,34 +63,54 @@ def forward(
         k = k.to(dtype=torch.float)
         v = v.to(dtype=torch.float)
 
-        output = torch.ops.llama.custom_sdpa(
-            q,
-            k,
-            v,
-            input_pos[0].item(),
-            None,  # Attention mask
-            0,  # dropout probability. Ignored by the code
-            True,  # is_causal
-        )
+        if self.use_attention_mask:
+            output = torch.ops.llama.custom_sdpa(
+                q,
+                k,
+                v,
+                input_pos[0].item(),
+                mask,  # Attention mask
+                0,  # dropout probability. Ignored by the code
+                False,  # is_causal
+            )
+        else:
+            output = torch.ops.llama.custom_sdpa(
+                q,
+                k,
+                v,
+                input_pos[0].item(),
+                None,  # Attention mask
+                0,  # dropout probability. Ignored by the code
+                True,  # is_causal
+            )
         return output.view(bsz, seqlen, self.dim).to(dtype=input_dtype)
 
 
-def _replace_sdpa_with_custom_op(module: torch.nn.Module):
+def _replace_sdpa_with_custom_op(
+    module: torch.nn.Module, use_attention_mask: bool = False
+):
     for name, child in module.named_children():
         if isinstance(child, SDPA):
             setattr(
                 module,
                 name,
-                SDPACustom(child.dim),
+                SDPACustom(
+                    child.dim,
+                    child.max_context_len,
+                    child.enable_dynamic_shape,
+                    use_attention_mask=use_attention_mask,
+                ),
             )
         else:
-            _replace_sdpa_with_custom_op(child)
+            _replace_sdpa_with_custom_op(child, use_attention_mask=use_attention_mask)
 
 
-def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
+def replace_sdpa_with_custom_op(
+    module: torch.nn.Module, use_attention_mask: bool = False
+) -> torch.nn.Module:
     from executorch.extension.llm.custom_ops import custom_ops  # noqa
 
-    _replace_sdpa_with_custom_op(module)
+    _replace_sdpa_with_custom_op(module, use_attention_mask=use_attention_mask)
     return module
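For readers skimming the diff, a small standalone sketch of what the new mask-slicing branch selects for a single decoded token; the causal-mask construction below is illustrative, and only the `narrow` / indexing lines mirror the code added above.

```python
import torch

# Illustrative (max_context_len, max_context_len) additive causal mask:
# 0.0 where a query position may attend, -inf elsewhere. The real mask is the
# one the llama model already carries and passes into SDPACustom.forward.
max_context_len = 8
causal_mask = torch.triu(
    torch.full((max_context_len, max_context_len), float("-inf")), diagonal=1
)

# Single-token decode at position 5 (input_pos holds the current position).
input_pos = torch.tensor([5])
seq_length = 1  # corresponds to q.size(2) in SDPACustom.forward

# Dynamic-shape path: take seq_length rows starting at input_pos[-1].
start_pos = input_pos[-1].item()
dynamic_rows = causal_mask.narrow(0, start_pos, seq_length)

# Static path: index the rows directly with input_pos.
static_rows = causal_mask[input_pos]

# Both select the same (1, max_context_len) row, which forward() then hands to
# torch.ops.llama.custom_sdpa with is_causal=False.
assert torch.equal(dynamic_rows, static_rows)
```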

examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py (+2 −2)

@@ -71,8 +71,8 @@ def test_simple(self, is_dynamic_shape=False):
         self.seq_len = 3
         self._init_cache()
         q, k_val, v_val = self._init_kv()
-        self.float_sdpa = SDPACustom(self.dim)
-        self.quantized_sdpa = SDPACustom(self.dim)
+        self.float_sdpa = SDPACustom(self.dim, self.max_context_len, True)
+        self.quantized_sdpa = SDPACustom(self.dim, self.max_context_len, True)
         k, v = self.custom_kv_cache.update(input_pos, k_val, v_val)
         float_out = self.float_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
         k, v = self.quantized_kv_cache.update(input_pos, k_val, v_val)
