
Commit d23facd

Authored Jan 2, 2025
Merge pull request #2388 from laclouis5/fix-mqa-v2
Fix MQA V2
2 parents 2d734d9 + 2d5277e · commit d23facd

File tree: 2 files changed, +23 −9 lines changed
 

Diff for: tests/test_layers.py (+18 −4)

@@ -2,7 +2,7 @@
 import torch
 import torch.nn as nn
 
-from timm.layers import create_act_layer, set_layer_config, get_act_layer, get_act_fn, Attention2d
+from timm.layers import create_act_layer, set_layer_config, get_act_layer, get_act_fn, Attention2d, MultiQueryAttentionV2
 
 import importlib
 import os
@@ -121,6 +121,23 @@ def test_get_act_fn_none():
     assert get_act_fn('') is None
 
 
+@pytest.mark.parametrize("dim", [128])
+@pytest.mark.parametrize("dim_out", [128, 256])
+@pytest.mark.parametrize("use_m", [True, False])
+def test_mqa_v2(dim, dim_out, use_m):
+    mqa = MultiQueryAttentionV2(dim, dim_out)
+
+    x = torch.randn(1, dim, 32, 48)
+    if use_m:
+        m = torch.randn(1, dim, 16, 24)
+    else:
+        m = None
+
+    y = mqa(x, m=m)
+
+    assert (y.shape) == (1, dim_out, 32, 48)
+
+
 @pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.parametrize("expand_first", [True, False])
 @pytest.mark.parametrize("head_first", [True, False])
@@ -141,6 +158,3 @@ def test_attn2d(bias, expand_first, head_first, attn_mask):
     o2 = attn(x, mask)
 
     assert torch.allclose(o1, o2, atol=1e-5), f"{torch.abs(o1 - o2).max()}"
-
-
-

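For orientation, the new test exercises both code paths of the layer: self-attention when no memory tensor is given, and cross-attention when one is. A minimal standalone sketch of the same usage (shapes taken from the test above; assumes a timm build that includes this fix):

import torch
from timm.layers import MultiQueryAttentionV2

# Layer projecting 128 input channels to 256 output channels.
mqa = MultiQueryAttentionV2(128, 256)

x = torch.randn(1, 128, 32, 48)   # NCHW query feature map
m = torch.randn(1, 128, 16, 24)   # optional memory with a different spatial size

y_self = mqa(x)                   # m=None: the layer attends over x itself
y_cross = mqa(x, m=m)             # cross-attention against the memory tensor

# The output keeps the query's spatial size, with dim_out channels.
assert y_self.shape == y_cross.shape == (1, 256, 32, 48)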
Diff for: timm/layers/attention2d.py (+5 −5)

@@ -59,24 +59,24 @@ def _reshape_input(self, t):
 
     def forward(self, x, m: Optional[torch.Tensor] = None):
         """Run layer computation."""
-        s = x.shape
-        m = m or x
+        b, _, h, w = x.shape
+        m = m if m is not None else x
 
         reshaped_x = self._reshape_input(x)
         reshaped_m = self._reshape_input(m)
 
         q = torch.einsum('bnd,hkd->bnhk', reshaped_x, self.query_proj)
         k = torch.einsum('bmd,dk->bmk', reshaped_m, self.key_proj)
 
-        attn = torch.einsum('bnhk,bmk->bnhm', q, k)
+        attn = torch.einsum('bnhk,bmk->bnhm', q, k) * self.scale
         attn = attn.softmax(dim=-1)
         attn = self.attn_drop(attn)
 
         v = torch.einsum('bmd,dv->bmv', reshaped_m, self.value_proj)
         o = torch.einsum('bnhm,bmv->bnhv', attn, v)
-        result = torch.einsum('bnhv,dhv->bnd', o, self.out_proj)
+        result = torch.einsum('bnhv,dhv->bdn', o, self.out_proj)
         result = self.proj_drop(result)
-        return result.reshape(s)
+        return result.reshape(b, -1, h, w)
 
 
 class MultiQueryAttention2d(nn.Module):
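
The changes above fix distinct bugs. First, `m = m or x` invokes `Tensor.__bool__`, which raises a RuntimeError for tensors with more than one element, so the cross-attention path crashed whenever a memory tensor was actually passed; the explicit None check avoids that. Second, the attention logits now get the layer's `self.scale` factor (the standard 1/sqrt(d_k) scaling of dot-product attention) before the softmax, which the old code dropped. A standalone sketch of the `or` pitfall, outside timm:

import torch

m = torch.randn(1, 128, 16, 24)
x = torch.randn(1, 128, 32, 48)

try:
    _ = m or x  # old pattern: truth value of a multi-element tensor is ambiguous
except RuntimeError as err:
    print(f"`m or x` fails: {err}")

m = m if m is not None else x  # fixed pattern: explicit None check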

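The last two changes work together: emitting `bdn` (channels before tokens) from the output projection means a single `reshape(b, -1, h, w)` restores NCHW even when `dim_out` differs from the input `dim`, whereas the old `reshape(s)` assumed the output shape equaled the input shape and otherwise scrambled the channel and token axes. A shapes-only sketch (the head count and per-head value dim here are illustrative assumptions, not timm's defaults):

import torch

B, H, W = 1, 32, 48
N = H * W                     # flattened token count
num_heads, value_dim = 8, 16  # illustrative head configuration
dim_out = 256

o = torch.randn(B, N, num_heads, value_dim)          # per-head attention outputs
out_proj = torch.randn(dim_out, num_heads, value_dim)

result = torch.einsum('bnhv,dhv->bdn', o, out_proj)  # channels-first over tokens
assert result.shape == (B, dim_out, N)

y = result.reshape(B, -1, H, W)                      # back to NCHW
assert y.shape == (B, dim_out, H, W)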