 from .. import nn
 
 
-class SelfAttention(nn.Module):
+# noinspection PyAbstractClass
+class SelfAttentionBase(nn.Module):
   """
-  Classic self attention
+  Shared base class for self attention
   """
   def __init__(self, *, key_dim_total: nn.Dim, value_dim_total: nn.Dim, num_heads: Union[int, nn.Dim],
                att_dropout: float = 0.):
     super().__init__()
-    if not isinstance(num_heads, nn.Dim):
-      num_heads = nn.SpatialDim("num_heads", num_heads)
     self.key_dim_total = key_dim_total
     self.key_dim_per_head = key_dim_total // num_heads
     self.value_dim_total = value_dim_total
@@ -26,6 +25,16 @@ def __init__(self, *, key_dim_total: nn.Dim, value_dim_total: nn.Dim, num_heads:
     self.expand_dim = nn.SpatialDim("self_att_expand_dim")
     self.att_dropout = att_dropout
 
+
+class SelfAttention(SelfAttentionBase):
+  """
+  Classic self attention
+  """
+  def __init__(self, *, key_dim_total: nn.Dim, value_dim_total: nn.Dim, num_heads: Union[int, nn.Dim],
+               att_dropout: float = 0.):
+    super().__init__(
+      key_dim_total=key_dim_total, value_dim_total=value_dim_total, num_heads=num_heads, att_dropout=att_dropout)
+
   def forward(self, source: nn.LayerRef, *, axis: nn.Dim) -> nn.Layer:
     """forward"""
     # noinspection DuplicatedCode
@@ -49,21 +58,18 @@ def forward(self, source: nn.LayerRef, *, axis: nn.Dim) -> nn.Layer:
     return output
 
 
-class SelfAttentionStep(nn.Module):
+class CausalSelfAttention(SelfAttentionBase):
+  pass  # TODO
+
+
+class CausalSelfAttentionStep(SelfAttentionBase):
   """
-  Auto-regressive self-attention
+  Causal auto-regressive self-attention
   """
   def __init__(self, *, key_dim_total: nn.Dim, value_dim_total: nn.Dim, num_heads: Union[int, nn.Dim],
                att_dropout: float = 0.):
-    super().__init__()
-    self.key_dim_total = key_dim_total
-    self.key_dim_per_head = key_dim_total // num_heads
-    self.value_dim_total = value_dim_total
-    self.value_dim_per_head = value_dim_total // num_heads
-    self.num_heads = num_heads
-    self.qkv = nn.Linear(key_dim_total * 2 + value_dim_total)
-    self.expand_dim = nn.DimensionTag(kind=nn.DimensionTag.Types.Spatial, description="self_att_expand_dim")
-    self.att_dropout = att_dropout
+    super().__init__(
+      key_dim_total=key_dim_total, value_dim_total=value_dim_total, num_heads=num_heads, att_dropout=att_dropout)
 
   def forward(self, source: nn.LayerRef, *, state: nn.LayerState) -> Tuple[nn.Layer, nn.LayerState]:
     """forward"""