
Commit ea5b77a

self attention, more
Some related discussion in #81.
1 parent dea44d0 commit ea5b77a

File tree

1 file changed: +21 -15 lines changed


nn/attention.py

Lines changed: 21 additions & 15 deletions
@@ -6,15 +6,14 @@
 from .. import nn
 
 
-class SelfAttention(nn.Module):
+# noinspection PyAbstractClass
+class SelfAttentionBase(nn.Module):
   """
-  Classic self attention
+  Shared base class for self attention
   """
   def __init__(self, *, key_dim_total: nn.Dim, value_dim_total: nn.Dim, num_heads: Union[int, nn.Dim],
                att_dropout: float = 0.):
     super().__init__()
-    if not isinstance(num_heads, nn.Dim):
-      num_heads = nn.SpatialDim("num_heads", num_heads)
     self.key_dim_total = key_dim_total
     self.key_dim_per_head = key_dim_total // num_heads
     self.value_dim_total = value_dim_total
@@ -26,6 +25,16 @@ def __init__(self, *, key_dim_total: nn.Dim, value_dim_total: nn.Dim, num_heads:
     self.expand_dim = nn.SpatialDim("self_att_expand_dim")
     self.att_dropout = att_dropout
 
+
+class SelfAttention(SelfAttentionBase):
+  """
+  Classic self attention
+  """
+  def __init__(self, *, key_dim_total: nn.Dim, value_dim_total: nn.Dim, num_heads: Union[int, nn.Dim],
+               att_dropout: float = 0.):
+    super().__init__(
+      key_dim_total=key_dim_total, value_dim_total=value_dim_total, num_heads=num_heads, att_dropout=att_dropout)
+
   def forward(self, source: nn.LayerRef, *, axis: nn.Dim) -> nn.Layer:
     """forward"""
     # noinspection DuplicatedCode
@@ -49,21 +58,18 @@ def forward(self, source: nn.LayerRef, *, axis: nn.Dim) -> nn.Layer:
     return output
 
 
-class SelfAttentionStep(nn.Module):
+class CausalSelfAttention(SelfAttentionBase):
+  pass  # TODO
+
+
+class CausalSelfAttentionStep(SelfAttentionBase):
   """
-  Auto-regressive self-attention
+  Causal auto-regressive self-attention
   """
   def __init__(self, *, key_dim_total: nn.Dim, value_dim_total: nn.Dim, num_heads: Union[int, nn.Dim],
                att_dropout: float = 0.):
-    super().__init__()
-    self.key_dim_total = key_dim_total
-    self.key_dim_per_head = key_dim_total // num_heads
-    self.value_dim_total = value_dim_total
-    self.value_dim_per_head = value_dim_total // num_heads
-    self.num_heads = num_heads
-    self.qkv = nn.Linear(key_dim_total * 2 + value_dim_total)
-    self.expand_dim = nn.DimensionTag(kind=nn.DimensionTag.Types.Spatial, description="self_att_expand_dim")
-    self.att_dropout = att_dropout
+    super().__init__(
+      key_dim_total=key_dim_total, value_dim_total=value_dim_total, num_heads=num_heads, att_dropout=att_dropout)
 
   def forward(self, source: nn.LayerRef, *, state: nn.LayerState) -> Tuple[nn.Layer, nn.LayerState]:
     """forward"""
