Skip to content

Commit 2d734d9

Browse files
laclouis5rwightman
authored andcommitted
Fixed unfused attn2d scale
1 parent 851e074 commit 2d734d9

File tree

2 files changed

+30
-5
lines changed

2 files changed

+30
-5
lines changed

Diff for: tests/test_layers.py

+26-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1+
import pytest
12
import torch
23
import torch.nn as nn
34

4-
from timm.layers import create_act_layer, set_layer_config, get_act_layer, get_act_fn
5+
from timm.layers import create_act_layer, set_layer_config, get_act_layer, get_act_fn, Attention2d
56

67
import importlib
78
import os
@@ -119,3 +120,27 @@ def test_get_act_fn_none():
119120
assert get_act_fn(None) is None
120121
assert get_act_fn('') is None
121122

123+
124+
@pytest.mark.parametrize("bias", [True, False])
125+
@pytest.mark.parametrize("expand_first", [True, False])
126+
@pytest.mark.parametrize("head_first", [True, False])
127+
@pytest.mark.parametrize("attn_mask", [True, False])
128+
def test_attn2d(bias, expand_first, head_first, attn_mask):
129+
x = torch.randn(1, 128, 32, 48)
130+
attn = Attention2d(
131+
128, 128, num_heads=4, bias=bias, expand_first=expand_first, head_first=head_first
132+
)
133+
134+
if attn_mask:
135+
mask = torch.randint(0, 1, size=(32 * 48, 32 * 48), dtype=torch.float32)
136+
else:
137+
mask = None
138+
139+
o1 = attn(x, mask)
140+
attn.fused_attn = False
141+
o2 = attn(x, mask)
142+
143+
assert torch.allclose(o1, o2, atol=1e-5), f"{torch.abs(o1 - o2).max()}"
144+
145+
146+

Diff for: timm/layers/attention2d.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,6 @@ def __init__(
312312
self.num_heads = num_heads
313313
self.dim_head = dim_attn // num_heads
314314
self.head_first = head_first
315-
self.scale = num_heads ** -0.5
316315
self.fused_attn = use_fused_attn()
317316

318317
self.qkv = nn.Conv2d(dim, dim_attn * 3, 1, bias=bias)
@@ -337,14 +336,15 @@ def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
337336
dropout_p=self.attn_drop.p if self.training else 0.,
338337
).transpose(-1, -2).reshape(B, -1, H, W)
339338
else:
340-
q = q * self.scale
341-
attn = q.transpose(-2, -1) @ k
339+
q = q.transpose(-1, -2)
340+
v = v.transpose(-1, -2)
341+
attn = q @ k * q.size(-1) ** -0.5
342342
if attn_mask is not None:
343343
# NOTE: assumes mask is float and in correct shape
344344
attn = attn + attn_mask
345345
attn = attn.softmax(dim=-1)
346346
attn = self.attn_drop(attn)
347-
x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)
347+
x = (attn @ v).transpose(-1, -2).reshape(B, -1, H, W)
348348

349349
x = self.proj(x)
350350
x = self.proj_drop(x)

0 commit comments

Comments
 (0)