Skip to content

Commit af3732e

Browse files
committed
Add coord attn and some variants that I had lying around from old experiments.
1 parent ae4d1bb commit af3732e

File tree

2 files changed

+350
-0
lines changed

2 files changed

+350
-0
lines changed

timm/layers/coord_attn.py

Lines changed: 341 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,341 @@
1+
""" Coordinate Attention and Variants
2+
3+
Coordinate Attention decomposes channel attention into two 1D feature encoding processes
4+
to capture long-range dependencies with precise positional information. This module includes
5+
the original implementation along with simplified and other variants.
6+
7+
Papers / References:
8+
- Coordinate Attention: `Coordinate Attention for Efficient Mobile Network Design` - https://arxiv.org/abs/2103.02907
9+
- Efficient Local Attention: `ELA: Efficient Local Attention for Deep Convolutional Neural Networks` - https://arxiv.org/abs/2403.01123
10+
11+
Hacked together by / Copyright 2025 Ross Wightman
12+
"""
13+
from typing import Optional, Type, Union
14+
15+
import torch
16+
from torch import nn
17+
18+
from .create_act import create_act_layer
19+
from .helpers import make_divisible
20+
from .norm import GroupNorm1
21+
22+
23+
class CoordAttn(nn.Module):
    """Coordinate Attention module for spatial feature recalibration.

    Introduced in "Coordinate Attention for Efficient Mobile Network Design" (CVPR 2021).
    The input is average-pooled along the width and along the height, both strips are pushed
    through one shared 1x1 conv bottleneck (conv -> norm -> act), and each strip is then expanded
    back to `channels` by its own 1x1 conv and gated. The two gates multiply the input.
    """

    def __init__(
            self,
            channels: int,
            rd_ratio: float = 1. / 16,
            rd_channels: Optional[int] = None,
            rd_divisor: int = 8,
            se_factor: float = 2/3,
            bias: bool = False,
            act_layer: Type[nn.Module] = nn.Hardswish,
            norm_layer: Optional[Type[nn.Module]] = nn.BatchNorm2d,
            gate_layer: Union[str, Type[nn.Module]] = 'sigmoid',
            has_skip: bool = False,
            device=None,
            dtype=None,
    ):
        """
        Args:
            channels: Number of input (and output) channels.
            rd_ratio: Reduction ratio used to derive the bottleneck width.
            rd_channels: Explicit bottleneck width; when falsy (None or 0) it is derived from
                `channels * rd_ratio * se_factor` via `make_divisible`.
            rd_divisor: Divisor passed to `make_divisible` for the derived bottleneck width.
            se_factor: Extra multiplier on the derived bottleneck width (keeps params similar to SE).
            bias: Whether the 1x1 convolutions carry a bias term.
            act_layer: Activation module class used inside the bottleneck.
            norm_layer: Normalization module class for the bottleneck, None for no normalization.
            gate_layer: Gate activation (name or module class), resolved by `create_act_layer`.
            has_skip: Whether to add the input back onto the gated output.
            device: Device forwarded to parameterized layers.
            dtype: Data type forwarded to parameterized layers.
        """
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        # Every conv in this module is a pointwise (1x1) projection with identical settings.
        pointwise = dict(kernel_size=1, stride=1, padding=0, bias=bias, **factory_kwargs)

        self.has_skip = has_skip
        if not rd_channels:
            rd_channels = make_divisible(channels * rd_ratio * se_factor, rd_divisor, round_limit=0.)

        # Shared bottleneck applied to the concatenated height/width strips.
        self.conv1 = nn.Conv2d(channels, rd_channels, **pointwise)
        if norm_layer is None:
            self.bn1 = nn.Identity()
        else:
            self.bn1 = norm_layer(rd_channels, **factory_kwargs)
        self.act = act_layer()

        # Separate expansion heads for the height and width gates.
        self.conv_h = nn.Conv2d(rd_channels, channels, **pointwise)
        self.conv_w = nn.Conv2d(rd_channels, channels, **pointwise)
        self.gate = create_act_layer(gate_layer)

    def forward(self, x):
        _n, _c, height, width = x.size()

        # Strip pooling; the width strip is transposed so both strips share dim 2.
        strip_h = x.mean(3, keepdim=True)                       # (N, C, H, 1)
        strip_w = x.mean(2, keepdim=True).transpose(-1, -2)     # (N, C, W, 1)

        # One bottleneck pass over both strips stacked along dim 2.
        joint = torch.cat([strip_h, strip_w], dim=2)            # (N, C, H + W, 1)
        joint = self.act(self.bn1(self.conv1(joint)))

        # Undo the stacking and restore the width strip's orientation.
        feat_h, feat_w = torch.split(joint, [height, width], dim=2)
        feat_w = feat_w.transpose(-1, -2)                       # (N, rd_c, 1, W)

        attn_h = self.gate(self.conv_h(feat_h))                 # (N, C, H, 1)
        attn_w = self.gate(self.conv_w(feat_w))                 # (N, C, 1, W)

        out = x * attn_w * attn_h
        if self.has_skip:
            out = out + x

        return out
99+
100+
101+
class SimpleCoordAttn(nn.Module):
    """Simplified Coordinate Attention variant.

    Differences from the conv-based coordinate attention:
      * linear layers instead of 1x1 convolutions
      * no normalization layer
      * the height and width logits are summed before a single gate is applied

    The core idea of separate height and width attention is retained.
    """

    def __init__(
            self,
            channels: int,
            rd_ratio: float = 0.25,
            rd_channels: Optional[int] = None,
            rd_divisor: int = 8,
            se_factor: float = 2 / 3,
            bias: bool = True,
            act_layer: Type[nn.Module] = nn.SiLU,
            gate_layer: Union[str, Type[nn.Module]] = 'sigmoid',
            has_skip: bool = False,
            device=None,
            dtype=None,
    ):
        """
        Args:
            channels: Number of input (and output) channels.
            rd_ratio: Reduction ratio used to derive the bottleneck width.
            rd_channels: Explicit bottleneck width; when falsy (None or 0) it is derived from
                `channels * rd_ratio * se_factor` via `make_divisible`.
            rd_divisor: Divisor passed to `make_divisible` for the derived bottleneck width.
            se_factor: Extra multiplier on the derived bottleneck width (keeps params similar to SE).
            bias: Whether the linear layers carry a bias term.
            act_layer: Activation module class used after the shared projection.
            gate_layer: Gate activation (name or module class), resolved by `create_act_layer`.
            has_skip: Whether to add the input back onto the gated output.
            device: Device forwarded to parameterized layers.
            dtype: Data type forwarded to parameterized layers.
        """
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.has_skip = has_skip

        if not rd_channels:
            rd_channels = make_divisible(channels * rd_ratio * se_factor, rd_divisor, round_limit=0.)

        # Shared reduction followed by one expansion head per spatial axis.
        self.fc1 = nn.Linear(channels, rd_channels, bias=bias, **factory_kwargs)
        self.act = act_layer()
        self.fc_h = nn.Linear(rd_channels, channels, bias=bias, **factory_kwargs)
        self.fc_w = nn.Linear(rd_channels, channels, bias=bias, **factory_kwargs)

        self.gate = create_act_layer(gate_layer)

    def forward(self, x):
        # Strip pooling, then move channels last so nn.Linear acts on them.
        strip_h = x.mean(dim=3).transpose(1, 2)  # (N, H, C)
        strip_w = x.mean(dim=2).transpose(1, 2)  # (N, W, C)

        # Shared bottleneck projection
        hidden_h = self.act(self.fc1(strip_h))  # (N, H, rd_c)
        hidden_w = self.act(self.fc1(strip_w))  # (N, W, rd_c)

        # Per-axis heads, reshaped back to channel-first broadcastable maps.
        logits_h = self.fc_h(hidden_h).transpose(1, 2).unsqueeze(-1)  # (N, C, H, 1)
        logits_w = self.fc_w(hidden_w).transpose(1, 2).unsqueeze(-2)  # (N, C, 1, W)

        # Logits are combined additively and gated once.
        out = x * self.gate(logits_h + logits_w)
        if self.has_skip:
            out = out + x

        return out
174+
175+
176+
class EfficientLocalAttn(nn.Module):
    """Efficient Local Attention.

    Lightweight alternative to Coordinate Attention that keeps the full channel count
    (no bottleneck). Each pooled strip is processed by a depthwise convolution that is
    1D in effect (kernel `(k, 1)` or `(1, k)`), followed by a norm layer (`GroupNorm1`
    by default), an activation and a gate.

    Paper: `ELA: Efficient Local Attention for Deep Convolutional Neural Networks`
        - https://arxiv.org/abs/2403.01123
    """

    def __init__(
            self,
            channels: int,
            kernel_size: int = 7,
            bias: bool = False,
            act_layer: Type[nn.Module] = nn.SiLU,
            gate_layer: Union[str, Type[nn.Module]] = 'sigmoid',
            norm_layer: Optional[Type[nn.Module]] = GroupNorm1,
            has_skip: bool = False,
            device=None,
            dtype=None,
    ):
        """
        Args:
            channels: Number of input (and output) channels.
            kernel_size: Kernel size for the 1D depthwise convolutions. Must be a positive odd
                integer so that `kernel_size // 2` padding preserves the strip length.
            bias: Whether to use bias in convolution layers.
            act_layer: Activation module class applied after normalization.
            gate_layer: Gate activation (name or module class), resolved by `create_act_layer`.
            norm_layer: Normalization module class, None for no normalization.
            has_skip: Whether to add residual skip connection to output.
            device: Device forwarded to parameterized layers.
            dtype: Data type forwarded to parameterized layers.

        Raises:
            ValueError: If `kernel_size` is not a positive odd integer.
        """
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        # With padding = k // 2 an even k produces strips of length H + 1 / W + 1, which no
        # longer broadcast against the input in forward(). Fail early with a clear message.
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(f'kernel_size must be a positive odd integer, got {kernel_size}.')
        self.has_skip = has_skip

        # Depthwise conv sliding along the height axis of the (N, C, H, 1) strip.
        self.conv_h = nn.Conv2d(
            channels, channels,
            kernel_size=(kernel_size, 1),
            stride=1,
            padding=(kernel_size // 2, 0),
            groups=channels,
            bias=bias,
            **dd
        )
        # Depthwise conv sliding along the width axis of the (N, C, 1, W) strip.
        self.conv_w = nn.Conv2d(
            channels, channels,
            kernel_size=(1, kernel_size),
            stride=1,
            padding=(0, kernel_size // 2),
            groups=channels,
            bias=bias,
            **dd
        )
        if norm_layer is not None:
            self.norm_h = norm_layer(channels, **dd)
            self.norm_w = norm_layer(channels, **dd)
        else:
            self.norm_h = nn.Identity()
            self.norm_w = nn.Identity()
        self.act = act_layer()
        self.gate = create_act_layer(gate_layer)

    def forward(self, x):
        identity = x

        # Strip pooling: (N, C, H, W) -> (N, C, H, 1) and (N, C, 1, W)
        x_h = x.mean(dim=3, keepdim=True)
        x_w = x.mean(dim=2, keepdim=True)

        # 1D conv + norm + act
        x_h = self.act(self.norm_h(self.conv_h(x_h)))  # (N, C, H, 1)
        x_w = self.act(self.norm_w(self.conv_w(x_w)))  # (N, C, 1, W)

        # Generate attention maps
        a_h = self.gate(x_h)  # (N, C, H, 1)
        a_w = self.gate(x_w)  # (N, C, 1, W)

        # Both gates broadcast over the missing spatial axis.
        out = identity * a_h * a_w
        if self.has_skip:
            out = out + identity

        return out
261+
262+
263+
class StripAttn(nn.Module):
    """Minimal Strip Attention.

    Lightweight spatial attention using strip pooling with optional learned refinement.
    The pooled height and width strips are (optionally) refined by depthwise convolutions,
    summed via broadcasting, gated, and multiplied onto the input.
    """

    def __init__(
            self,
            channels: int,
            use_conv: bool = True,
            kernel_size: int = 3,
            bias: bool = False,
            gate_layer: Union[str, Type[nn.Module]] = 'sigmoid',
            has_skip: bool = False,
            device=None,
            dtype=None,
            **_,
    ):
        """
        Args:
            channels: Number of input (and output) channels.
            use_conv: Whether to apply depthwise convolutions for learned spatial refinement.
            kernel_size: Kernel size for 1D depthwise convolutions when use_conv is True. Must be
                a positive odd integer in that case so `kernel_size // 2` padding preserves the
                strip length; ignored when use_conv is False.
            bias: Whether to use bias in convolution layers.
            gate_layer: Gate activation (name or module class), resolved by `create_act_layer`.
            has_skip: Whether to add residual skip connection to output.
            device: Device forwarded to parameterized layers.
            dtype: Data type forwarded to parameterized layers.
            **_: Extra keyword arguments are accepted and ignored.

        Raises:
            ValueError: If `use_conv` is True and `kernel_size` is not a positive odd integer.
        """
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        # With padding = k // 2 an even k produces strips of length H + 1 / W + 1, which no
        # longer line up with the input in forward(). Fail early with a clear message.
        if use_conv and (kernel_size < 1 or kernel_size % 2 == 0):
            raise ValueError(f'kernel_size must be a positive odd integer, got {kernel_size}.')
        self.has_skip = has_skip
        self.use_conv = use_conv

        if use_conv:
            # Depthwise conv sliding along the height axis of the (N, C, H, 1) strip.
            self.conv_h = nn.Conv2d(
                channels, channels,
                kernel_size=(kernel_size, 1),
                stride=1,
                padding=(kernel_size // 2, 0),
                groups=channels,
                bias=bias,
                **dd
            )
            # Depthwise conv sliding along the width axis of the (N, C, 1, W) strip.
            self.conv_w = nn.Conv2d(
                channels, channels,
                kernel_size=(1, kernel_size),
                stride=1,
                padding=(0, kernel_size // 2),
                groups=channels,
                bias=bias,
                **dd
            )
        else:
            self.conv_h = nn.Identity()
            self.conv_w = nn.Identity()

        self.gate = create_act_layer(gate_layer)

    def forward(self, x):
        identity = x

        # Strip pooling
        x_h = x.mean(dim=3, keepdim=True)  # (N, C, H, 1)
        x_w = x.mean(dim=2, keepdim=True)  # (N, C, 1, W)

        # Optional learned refinement
        x_h = self.conv_h(x_h)
        x_w = self.conv_w(x_w)

        # Combine and gate
        a_hw = self.gate(x_h + x_w)  # broadcasts to (N, C, H, W)

        out = identity * a_hw
        if self.has_skip:
            out = out + identity

        return out
341+

timm/layers/create_attn.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from .bottleneck_attn import BottleneckAttn
99
from .cbam import CbamModule, LightCbamModule
10+
from .coord_attn import CoordAttn, EfficientLocalAttn, StripAttn, SimpleCoordAttn
1011
from .eca import EcaModule, CecaModule
1112
from .gather_excite import GatherExcite
1213
from .global_context import GlobalContext
@@ -47,6 +48,14 @@ def get_attn(attn_type):
4748
module_cls = CbamModule
4849
elif attn_type == 'lcbam':
4950
module_cls = LightCbamModule
51+
elif attn_type == 'coord':
52+
module_cls = CoordAttn
53+
elif attn_type == 'scoord':
54+
module_cls = SimpleCoordAttn
55+
elif attn_type == 'ela':
56+
module_cls = EfficientLocalAttn
57+
elif attn_type == 'strip':
58+
module_cls = StripAttn
5059

5160
# Attention / attention-like modules w/ significant params
5261
# Typically replace some of the existing workhorse convs in a network architecture.

0 commit comments

Comments
 (0)