Commit 6a29fe9

[Executorch][llama] Allow custom sdpa op replacement pass to leverage attention mask
Differential Revision: D73222736
Pull Request resolved: #10285
1 parent 1c6e332 commit 6a29fe9

2 files changed (+52 -16)

examples/models/llama/source_transformation/sdpa.py (+50 -14)
@@ -22,9 +22,15 @@ class SDPACustom(torch.nn.Module):
     def __init__(
         self,
         dim: int,
+        max_context_len,
+        enable_dynamic_shape,
+        use_attention_mask: bool = False,
     ):
         super().__init__()
         self.dim = dim
+        self.max_context_len = max_context_len
+        self.use_attention_mask = use_attention_mask
+        self.enable_dynamic_shape = enable_dynamic_shape
 
     def forward(
         self,
@@ -36,6 +42,16 @@ def forward(
         seqlen,
         mask,
     ):
+        if self.use_attention_mask:
+            if self.enable_dynamic_shape:
+                start_pos = input_pos[-1].item()
+                torch._check_is_size(start_pos)
+                torch._check(start_pos < self.max_context_len)
+                seq_length = q.size(2)
+                mask = mask.narrow(0, start_pos, seq_length)
+            else:
+                mask = mask[input_pos]
+
         q = q.transpose(1, 2)  # (bs, seqlen, n_local_heads, head_dim)
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)
@@ -47,34 +63,54 @@ def forward(
         k = k.to(dtype=torch.float)
         v = v.to(dtype=torch.float)
 
-        output = torch.ops.llama.custom_sdpa(
-            q,
-            k,
-            v,
-            input_pos[0].item(),
-            None,  # Attention mask
-            0,  # dropout probability. Ignored by the code
-            True,  # is_causal
-        )
+        if self.use_attention_mask:
+            output = torch.ops.llama.custom_sdpa(
+                q,
+                k,
+                v,
+                input_pos[0].item(),
+                mask,  # Attention mask
+                0,  # dropout probability. Ignored by the code
+                False,  # is_causal
+            )
+        else:
+            output = torch.ops.llama.custom_sdpa(
+                q,
+                k,
+                v,
+                input_pos[0].item(),
+                None,  # Attention mask
+                0,  # dropout probability. Ignored by the code
+                True,  # is_causal
+            )
         return output.view(bsz, seqlen, self.dim).to(dtype=input_dtype)
 
 
-def _replace_sdpa_with_custom_op(module: torch.nn.Module):
+def _replace_sdpa_with_custom_op(
+    module: torch.nn.Module, use_attention_mask: bool = False
+):
     for name, child in module.named_children():
         if isinstance(child, SDPA):
             setattr(
                 module,
                 name,
-                SDPACustom(child.dim),
+                SDPACustom(
+                    child.dim,
+                    child.max_context_len,
+                    child.enable_dynamic_shape,
+                    use_attention_mask=use_attention_mask,
+                ),
             )
         else:
-            _replace_sdpa_with_custom_op(child)
+            _replace_sdpa_with_custom_op(child, use_attention_mask=use_attention_mask)
 
 
-def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
+def replace_sdpa_with_custom_op(
+    module: torch.nn.Module, use_attention_mask: bool = False
+) -> torch.nn.Module:
     from executorch.extension.llm.custom_ops import custom_ops  # noqa
 
-    _replace_sdpa_with_custom_op(module)
+    _replace_sdpa_with_custom_op(module, use_attention_mask=use_attention_mask)
     return module
 
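For context on the new mask handling in forward: under dynamic shapes the mask is sliced with narrow (with torch._check_is_size/torch._check bounding start_pos so the slice stays export-friendly), while the static path simply indexes with input_pos. Below is a minimal standalone sketch of why both paths select the same rows, assuming mask is a precomputed additive causal mask of shape (max_context_len, max_context_len) and contiguous input positions; the toy values are illustrative and not part of the commit.

import torch

max_context_len = 8
# Stand-in for the model's causal mask buffer: 0 where attention is
# allowed, -inf strictly above the diagonal.
causal_mask = torch.triu(
    torch.full((max_context_len, max_context_len), float("-inf")), diagonal=1
)

start_pos, seq_length = 3, 2
input_pos = torch.arange(start_pos, start_pos + seq_length)

# Dynamic-shape path: take rows [start_pos, start_pos + seq_length).
narrowed = causal_mask.narrow(0, start_pos, seq_length)

# Static path: gather the same rows by position index.
indexed = causal_mask[input_pos]

# Both are (seq_length, max_context_len) views of the same slice.
assert torch.equal(narrowed, indexed)

With an explicit mask in hand, custom_sdpa is then called with is_causal=False, since causality is already encoded in the rows of the mask being passed.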
examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py (+2 -2)
@@ -71,8 +71,8 @@ def test_simple(self, is_dynamic_shape=False):
         self.seq_len = 3
         self._init_cache()
         q, k_val, v_val = self._init_kv()
-        self.float_sdpa = SDPACustom(self.dim)
-        self.quantized_sdpa = SDPACustom(self.dim)
+        self.float_sdpa = SDPACustom(self.dim, self.max_context_len, True)
+        self.quantized_sdpa = SDPACustom(self.dim, self.max_context_len, True)
         k, v = self.custom_kv_cache.update(input_pos, k_val, v_val)
         float_out = self.float_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
         k, v = self.quantized_kv_cache.update(input_pos, k_val, v_val)
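The test update above only supplies the two new required constructor arguments (max_context_len, enable_dynamic_shape); mask handling stays off because use_attention_mask defaults to False. As a hedged usage sketch of the updated pass, assuming the import path from the file's location in the repo and using load_llama_model as a hypothetical placeholder for however the eager model is built:

from executorch.examples.models.llama.source_transformation.sdpa import (
    replace_sdpa_with_custom_op,
)

# Hypothetical loader: any eager module whose attention layers hold SDPA
# submodules (not part of this commit).
model = load_llama_model()

# Without the flag, behavior is unchanged: custom_sdpa sees mask=None and
# is_causal=True. With use_attention_mask=True, each SDPACustom slices the
# model's mask to the current positions and calls custom_sdpa with that
# mask and is_causal=False.
model = replace_sdpa_with_custom_op(model, use_attention_mask=True)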
