 """
 
 from typing import Tuple, List, Union
+from .. import nn
+from . import LayerRef
 
-from . import Module, ModuleList, LayerRef, Linear, dropout, layer_norm, batch_norm, Conv, swish, glu, split_dims, \
-  merge_dims, pool
 
-
-class _PositionwiseFeedForward(Module):
+class _PositionwiseFeedForward(nn.Module):
   """
   Conformer position-wise feedforward neural network layer
   FF -> Activation -> Dropout -> FF
@@ -28,14 +27,14 @@ def __init__(self, d_model: int, d_ff: int, dropout: float, activation, l2: floa
     self.dropout = dropout
     self.activation = activation
 
-    self.linear1 = Linear(n_out=d_ff, l2=l2)
-    self.linear2 = Linear(n_out=d_model, l2=l2)
+    self.linear1 = nn.Linear(n_out=d_ff, l2=l2)
+    self.linear2 = nn.Linear(n_out=d_model, l2=l2)
 
   def forward(self, inp: LayerRef) -> LayerRef:
-    return self.linear2(dropout(self.activation(self.linear1(inp)), dropout=self.dropout))
+    return self.linear2(nn.dropout(self.activation(self.linear1(inp)), dropout=self.dropout))
 
 
-class _ConformerConvBlock(Module):
+class _ConformerConvBlock(nn.Module):
   """
   Conformer convolution block
   FF -> GLU -> depthwise conv -> BN -> Swish -> FF
@@ -49,21 +48,21 @@ def __init__(self, d_model: int, kernel_size: Tuple[int], l2: float = 0.0):
     """
     super().__init__()
 
-    self.positionwise_conv1 = Linear(n_out=d_model * 2, l2=l2)
-    self.depthwise_conv = Conv(n_out=d_model, filter_size=kernel_size, groups=d_model, l2=l2, padding='same')
-    self.positionwise_conv2 = Linear(n_out=d_model, l2=l2)
+    self.positionwise_conv1 = nn.Linear(n_out=d_model * 2, l2=l2)
+    self.depthwise_conv = nn.Conv(n_out=d_model, filter_size=kernel_size, groups=d_model, l2=l2, padding='same')
+    self.positionwise_conv2 = nn.Linear(n_out=d_model, l2=l2)
 
   def forward(self, inp: LayerRef) -> LayerRef:
     x_conv1 = self.positionwise_conv1(inp)
-    x_act = glu(x_conv1)
+    x_act = nn.glu(x_conv1)
     x_depthwise_conv = self.depthwise_conv(x_act)
-    x_bn = batch_norm(x_depthwise_conv)
-    x_swish = swish(x_bn)
+    x_bn = nn.batch_norm(x_depthwise_conv)
+    x_swish = nn.swish(x_bn)
     x_conv2 = self.positionwise_conv2(x_swish)
     return x_conv2
 
 
-class _ConformerConvSubsampleLayer(Module):
+class _ConformerConvSubsampleLayer(nn.Module):
   """
   Conv 2D block with optional max-pooling
   """
@@ -85,24 +84,24 @@ def __init__(self, filter_sizes: List[Tuple[int, ...]], pool_sizes: Union[List[T
     self.dropout = dropout
     self.pool_sizes = pool_sizes
 
-    self.conv_layers = ModuleList()
+    self.conv_layers = nn.ModuleList()
     for filter_size, channel_size in zip(filter_sizes, channel_sizes):
       self.conv_layers.append(
-        Conv(l2=l2, activation=act, filter_size=filter_size, n_out=channel_size, padding=padding))
+        nn.Conv(l2=l2, activation=act, filter_size=filter_size, n_out=channel_size, padding=padding))
 
   def forward(self, inp: LayerRef) -> LayerRef:
-    x = split_dims(inp, axis='F', dims=(-1, 1))
+    x = nn.split_dims(inp, axis='F', dims=(-1, 1))
     for i, conv_layer in enumerate(self.conv_layers):
       x = conv_layer(x)
       if self.pool_sizes and i < len(self.pool_sizes):
-        x = pool(x, pool_size=self.pool_sizes[i], padding='same', mode='max')
+        x = nn.pool(x, pool_size=self.pool_sizes[i], padding='same', mode='max')
       if self.dropout:
-        x = dropout(x, dropout=self.dropout)
-    out = merge_dims(x, axes='static')
+        x = nn.dropout(x, dropout=self.dropout)
+    out = nn.merge_dims(x, axes='static')
     return out
 
 
-class ConformerEncoderLayer(Module):
+class ConformerEncoderLayer(nn.Module):
   """
   Represents a conformer block
   """
@@ -135,36 +134,36 @@ def __init__(self, conv_kernel_size: Tuple[int], ff_act, ff_dim: int, dropout: f
 
   def forward(self, inp: LayerRef) -> LayerRef:
     # FFN
-    x_ffn1_ln = layer_norm(inp)
+    x_ffn1_ln = nn.layer_norm(inp)
     x_ffn1 = self.ffn1(x_ffn1_ln)
-    x_ffn1_out = 0.5 * dropout(x_ffn1, dropout=self.dropout) + inp
+    x_ffn1_out = 0.5 * nn.dropout(x_ffn1, dropout=self.dropout) + inp
 
     # MHSA
-    x_mhsa_ln = layer_norm(x_ffn1_out)
+    x_mhsa_ln = nn.layer_norm(x_ffn1_out)
     x_mhsa = self.mhsa_module(x_mhsa_ln)
     x_mhsa_out = x_mhsa + x_ffn1_out
 
     # Conv
-    x_conv_ln = layer_norm(x_mhsa_out)
+    x_conv_ln = nn.layer_norm(x_mhsa_out)
     x_conv = self.conv_module(x_conv_ln)
-    x_conv_out = dropout(x_conv, dropout=self.dropout) + x_mhsa_out
+    x_conv_out = nn.dropout(x_conv, dropout=self.dropout) + x_mhsa_out
 
     # FFN
-    x_ffn2_ln = layer_norm(x_conv_out)
+    x_ffn2_ln = nn.layer_norm(x_conv_out)
     x_ffn2 = self.ffn2(x_ffn2_ln)
-    x_ffn2_out = 0.5 * dropout(x_ffn2, dropout=self.dropout) + x_conv_out
+    x_ffn2_out = 0.5 * nn.dropout(x_ffn2, dropout=self.dropout) + x_conv_out
 
     # last LN layer
-    return layer_norm(x_ffn2_out)
+    return nn.layer_norm(x_ffn2_out)
 
 
-class ConformerEncoder(Module):
+class ConformerEncoder(nn.Module):
   """
   Represents Conformer encoder architecture
   """
 
-  def __init__(self, encoder_layer: Module, num_blocks: int, conv_kernel_size: Tuple[int, ...] = (32,), ff_act=swish,
-               ff_dim: int = 512, dropout: float = 0.1, att_dropout: float = 0.1, enc_key_dim: int = 256,
+  def __init__(self, encoder_layer: nn.Module, num_blocks: int, conv_kernel_size: Tuple[int, ...] = (32,),
+               ff_act=nn.swish, ff_dim: int = 512, dropout: float = 0.1, att_dropout: float = 0.1, enc_key_dim: int = 256,
                att_n_heads: int = 4, l2: float = 0.0):
     """
     :param encoder_layer:
@@ -186,9 +185,9 @@ def __init__(self, encoder_layer: Module, num_blocks: int, conv_kernel_size: Tup
       filter_sizes=[(3, 3), (3, 3)], pool_sizes=[(2, 2), (2, 2)], channel_sizes=[enc_key_dim, enc_key_dim],
       l2=l2, dropout=dropout)
 
-    self.linear = Linear(n_out=enc_key_dim, l2=l2, with_bias=False)
+    self.linear = nn.Linear(n_out=enc_key_dim, l2=l2, with_bias=False)
 
-    self.conformer_blocks = ModuleList([
+    self.conformer_blocks = nn.ModuleList([
       encoder_layer(
         conv_kernel_size=conv_kernel_size, ff_act=ff_act, ff_dim=ff_dim, dropout=dropout,
         att_dropout=att_dropout, enc_key_dim=enc_key_dim, att_n_heads=att_n_heads, l2=l2
@@ -199,7 +198,7 @@ def __init__(self, encoder_layer: Module, num_blocks: int, conv_kernel_size: Tup
   def forward(self, inp: LayerRef) -> LayerRef:
     x_subsample = self.conv_subsample_layer(inp)
     x_linear = self.linear(x_subsample)
-    x = dropout(x_linear, dropout=self.dropout)
+    x = nn.dropout(x_linear, dropout=self.dropout)
     for conformer_block in self.conformer_blocks:
       x = conformer_block(x)
     return x
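
For reference, a minimal usage sketch of the refactored API (not part of this commit). It assumes this module is importable as `conformer` and that `inp` is a `LayerRef` to the input features; `ConformerEncoder` takes the layer class itself and instantiates it once per block, as the `nn.ModuleList` construction above shows. All keyword arguments are the signature defaults except `num_blocks` and `l2`, which are illustrative values.

  # Hypothetical usage, under the assumptions stated above.
  encoder = conformer.ConformerEncoder(
    encoder_layer=conformer.ConformerEncoderLayer,  # layer class, instantiated num_blocks times
    num_blocks=12,           # illustrative value, not a default
    conv_kernel_size=(32,),  # depthwise conv kernel size
    ff_dim=512, enc_key_dim=256, att_n_heads=4,
    dropout=0.1, att_dropout=0.1, l2=0.0001)
  out = encoder(inp)  # conv subsample -> linear -> dropout -> conformer blocks

The 0.5 factor on both feed-forward residual branches is the Macaron-style half-step scaling from the Conformer paper, and each block ends with a final layer norm, as the `ConformerEncoderLayer.forward` hunk shows.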