
Commit 6e3a651

fix params naming
1 parent d2dc800 commit 6e3a651

File tree

1 file changed (+26, -26 lines)

nn/conformer.py

Lines changed: 26 additions & 26 deletions
@@ -14,10 +14,10 @@ class _PositionwiseFeedForward(nn.Module):
   FF -> Activation -> Dropout -> FF
   """
 
-  def __init__(self, d_model: int, d_ff: int, dropout: float, activation, l2: float = 0.0):
+  def __init__(self, dim_model: int, dim_ff: int, dropout: float, activation, l2: float = 0.0):
     """
-    :param d_model:
-    :param d_ff:
+    :param dim_model:
+    :param dim_ff:
     :param dropout:
     :param activation:
     :param l2:
@@ -27,8 +27,8 @@ def __init__(self, d_model: int, d_ff: int, dropout: float, activation, l2: floa
     self.dropout = dropout
     self.activation = activation
 
-    self.linear1 = nn.Linear(n_out=d_ff, l2=l2)
-    self.linear2 = nn.Linear(n_out=d_model, l2=l2)
+    self.linear1 = nn.Linear(n_out=dim_ff, l2=l2)
+    self.linear2 = nn.Linear(n_out=dim_model, l2=l2)
 
   def forward(self, inp: LayerRef) -> LayerRef:
     return self.linear2(nn.dropout(self.activation(self.linear1(inp)), dropout=self.dropout))
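For callers, only the keyword names change here. A minimal construction sketch under the new names, assuming a returnn_common-style package layout for the imports (the import paths and concrete values are illustrative, not part of this commit):

# Sketch only; import paths and values are assumptions for illustration.
from returnn_common import nn
from returnn_common.nn.conformer import _PositionwiseFeedForward

# Callers now pass dim_model/dim_ff instead of d_model/d_ff.
ffn = _PositionwiseFeedForward(dim_model=256, dim_ff=512, dropout=0.1, activation=nn.swish, l2=0.0)
# out = ffn(x)  # x: LayerRef; runs FF -> activation -> dropout -> FF as in forward() above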
@@ -40,17 +40,17 @@ class _ConformerConvBlock(nn.Module):
   FF -> GLU -> depthwise conv -> BN -> Swish -> FF
   """
 
-  def __init__(self, d_model: int, kernel_size: Tuple[int], l2: float = 0.0):
+  def __init__(self, dim_model: int, kernel_size: Tuple[int], l2: float = 0.0):
     """
-    :param d_model:
+    :param dim_model:
     :param kernel_size:
     :param l2:
     """
     super().__init__()
 
-    self.positionwise_conv1 = nn.Linear(n_out=d_model * 2, l2=l2)
-    self.depthwise_conv = nn.Conv(n_out=d_model, filter_size=kernel_size, groups=d_model, l2=l2, padding='same')
-    self.positionwise_conv2 = nn.Linear(n_out=d_model, l2=l2)
+    self.positionwise_conv1 = nn.Linear(n_out=dim_model * 2, l2=l2)
+    self.depthwise_conv = nn.Conv(n_out=dim_model, filter_size=kernel_size, groups=dim_model, l2=l2, padding='same')
+    self.positionwise_conv2 = nn.Linear(n_out=dim_model, l2=l2)
 
   def forward(self, inp: LayerRef) -> LayerRef:
     x_conv1 = self.positionwise_conv1(inp)
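The conv block follows the same pattern, d_model becoming dim_model. A sketch (import path and values are assumptions):

# Sketch only; import path and values are assumptions.
from returnn_common.nn.conformer import _ConformerConvBlock

conv_block = _ConformerConvBlock(dim_model=256, kernel_size=(32,), l2=0.0)
# out = conv_block(x)  # x: LayerRef; runs FF -> GLU -> depthwise conv -> BN -> Swish -> FF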
@@ -68,15 +68,15 @@ class _ConformerConvSubsampleLayer(nn.Module):
   """
 
   def __init__(self, filter_sizes: List[Tuple[int, ...]], pool_sizes: Union[List[Tuple[int, ...]], None],
-               channel_sizes: List[int], l2: float = 0.0, dropout: float = 0.3, act: str = 'relu',
+               channel_sizes: List[int], l2: float = 0.0, dropout: float = 0.3, activation: str = 'relu',
                padding: str = 'same'):
     """
     :param filter_sizes:
     :param pool_sizes:
     :param channel_sizes:
     :param l2:
     :param dropout:
-    :param act:
+    :param activation:
     :param padding:
     """
     super().__init__()
@@ -87,7 +87,7 @@ def __init__(self, filter_sizes: List[Tuple[int, ...]], pool_sizes: Union[List[T
     self.conv_layers = nn.ModuleList()
     for filter_size, channel_size in zip(filter_sizes, channel_sizes):
       self.conv_layers.append(
-        nn.Conv(l2=l2, activation=act, filter_size=filter_size, n_out=channel_size, padding=padding))
+        nn.Conv(l2=l2, activation=activation, filter_size=filter_size, n_out=channel_size, padding=padding))
 
   def forward(self, inp: LayerRef) -> LayerRef:
     x = nn.split_dims(inp, axis='F', dims=(-1, 1))
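For the subsample layer the keyword act becomes activation. A sketch of the new call (import path and the filter/pool/channel sizes are illustrative assumptions, not defaults from this file):

# Sketch only; import path and concrete sizes are assumptions.
from returnn_common.nn.conformer import _ConformerConvSubsampleLayer

subsample = _ConformerConvSubsampleLayer(
  filter_sizes=[(3, 3), (3, 3)], pool_sizes=[(2, 2), (2, 2)], channel_sizes=[64, 64],
  dropout=0.3, activation='relu', l2=0.0, padding='same')  # passing act= would now raise a TypeError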
@@ -106,31 +106,31 @@ class ConformerEncoderLayer(nn.Module):
   Represents a conformer block
   """
 
-  def __init__(self, conv_kernel_size: Tuple[int], ff_act, ff_dim: int, dropout: float, att_dropout: float,
-               enc_key_dim: int, att_n_heads: int, l2: float):
+  def __init__(self, conv_kernel_size: Tuple[int], activation_ff, dim_ff: int, dropout: float, att_dropout: float,
+               enc_key_dim: int, num_heads: int, l2: float):
     """
     :param conv_kernel_size:
-    :param ff_act:
+    :param activation_ff:
     :param ff_dim:
     :param dropout:
     :param att_dropout:
     :param enc_key_dim:
-    :param att_n_heads:
+    :param num_heads:
     :param l2:
     """
     super().__init__()
 
     self.dropout = dropout
 
     self.ffn1 = _PositionwiseFeedForward(
-      d_model=enc_key_dim, d_ff=ff_dim, dropout=dropout, activation=ff_act, l2=l2)
+      dim_model=enc_key_dim, dim_ff=dim_ff, dropout=dropout, activation=activation_ff, l2=l2)
 
     self.ffn2 = _PositionwiseFeedForward(
-      d_model=enc_key_dim, d_ff=ff_dim, dropout=dropout, activation=ff_act, l2=l2)
+      dim_model=enc_key_dim, dim_ff=dim_ff, dropout=dropout, activation=activation_ff, l2=l2)
 
-    self.conv_module = _ConformerConvBlock(d_model=enc_key_dim, kernel_size=conv_kernel_size)
+    self.conv_module = _ConformerConvBlock(dim_model=enc_key_dim, kernel_size=conv_kernel_size)
 
-    self.mhsa_module = MultiheadAttention(d_model, att_n_heads, dropout=att_dropout) # TODO: to be implemented
+    self.mhsa_module = self.conv_module #MultiheadAttention(enc_key_dim, num_heads, dropout=att_dropout) # TODO: to be implemented
 
   def forward(self, inp: LayerRef) -> LayerRef:
     # FFN
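At the block level, ff_act/ff_dim/att_n_heads become activation_ff/dim_ff/num_heads; note the hunk also points mhsa_module at conv_module as a placeholder until MultiheadAttention is implemented. A sketch of constructing one block with the new keywords (import paths and values are assumptions):

# Sketch only; import paths and values are assumptions.
from returnn_common import nn
from returnn_common.nn.conformer import ConformerEncoderLayer

layer = ConformerEncoderLayer(
  conv_kernel_size=(32,), activation_ff=nn.swish, dim_ff=512, dropout=0.1,
  att_dropout=0.1, enc_key_dim=256, num_heads=4, l2=0.0)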
@@ -163,8 +163,8 @@ class ConformerEncoder(nn.Module):
   """
 
   def __init__(self, encoder_layer: nn.Module, num_blocks: int, conv_kernel_size: Tuple[int, ...] = (32,),
-               ff_act=nn.swish, ff_dim: int = 512, dropout: float = 0.1, att_dropout: float = 0.1, enc_key_dim: int = 256,
-               att_n_heads: int = 4, l2: float = 0.0):
+               activation_ff=nn.swish, dim_ff: int = 512, dropout: float = 0.1, att_dropout: float = 0.1, enc_key_dim: int = 256,
+               num_heads: int = 4, l2: float = 0.0):
     """
     :param encoder_layer:
     :param num_blocks:
@@ -189,8 +189,8 @@ def __init__(self, encoder_layer: nn.Module, num_blocks: int, conv_kernel_size:
 
     self.conformer_blocks = nn.Sequential([
       encoder_layer(
-        conv_kernel_size=conv_kernel_size, ff_act=ff_act, ff_dim=ff_dim, dropout=dropout,
-        att_dropout=att_dropout, enc_key_dim=enc_key_dim, att_n_heads=att_n_heads, l2=l2
+        conv_kernel_size=conv_kernel_size, activation_ff=activation_ff, dim_ff=dim_ff, dropout=dropout,
+        att_dropout=att_dropout, enc_key_dim=enc_key_dim, num_heads=num_heads, l2=l2
       )
       for _ in range(num_blocks)
     ])
@@ -200,4 +200,4 @@ def forward(self, inp: LayerRef) -> LayerRef:
     x_linear = self.linear(x_subsample)
     x = nn.dropout(x_linear, dropout=self.dropout)
     x = self.conformer_blocks(x)
-    return x
+    return x
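Putting it together, the top-level encoder forwards the renamed keywords to each block, and encoder_layer is passed as the class it instantiates per block. A sketch of a full construction (import paths and num_blocks are assumptions; the other values mirror the defaults shown in the signature above):

# Sketch only; import paths and num_blocks are assumptions.
from returnn_common import nn
from returnn_common.nn.conformer import ConformerEncoder, ConformerEncoderLayer

encoder = ConformerEncoder(
  encoder_layer=ConformerEncoderLayer, num_blocks=12, conv_kernel_size=(32,),
  activation_ff=nn.swish, dim_ff=512, dropout=0.1, att_dropout=0.1,
  enc_key_dim=256, num_heads=4, l2=0.0)
# out = encoder(x)  # x: LayerRef; forward() subsamples, projects, applies dropout, then the blocks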
