Commit 4feeabe

better import

1 parent 71d086c
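
The change is a pure import refactor: instead of importing every symbol from the nn package one by one, nn/conformer.py now imports the namespace once and qualifies each use (nn.Linear, nn.dropout, ...). A minimal sketch of the two styles, assuming a module living in the same package as conformer.py; the class here is hypothetical, only the nn.* calls come from the diff below:

```python
# Before: every name imported explicitly; the list drifts as the code changes:
# from . import Module, ModuleList, LayerRef, Linear, dropout, layer_norm, ...

# After: one namespace import plus the layer-reference type:
from .. import nn
from . import LayerRef


class _ExampleBlock(nn.Module):  # hypothetical class, for illustration only
  """Linear + dropout in the namespaced style this commit adopts."""

  def __init__(self, dim: int):
    super().__init__()
    self.linear = nn.Linear(n_out=dim)  # was: Linear(n_out=dim)

  def forward(self, inp: LayerRef) -> LayerRef:
    # was: dropout(self.linear(inp), dropout=0.1)
    return nn.dropout(self.linear(inp), dropout=0.1)
```

Call sites become self-documenting (nn.Linear rather than a bare Linear), and the import block shrinks from a backslash-continued two-liner to two short lines.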

1 file changed: +35 −36 lines


nn/conformer.py

Lines changed: 35 additions & 36 deletions
@@ -4,12 +4,11 @@
 """
 
 from typing import Tuple, List, Union
+from .. import nn
+from . import LayerRef
 
-from . import Module, ModuleList, LayerRef, Linear, dropout, layer_norm, batch_norm, Conv, swish, glu, split_dims, \
-  merge_dims, pool
 
-
-class _PositionwiseFeedForward(Module):
+class _PositionwiseFeedForward(nn.Module):
   """
   Conformer position-wise feedforward neural network layer
   FF -> Activation -> Dropout -> FF
@@ -28,14 +27,14 @@ def __init__(self, d_model: int, d_ff: int, dropout: float, activation, l2: floa
     self.dropout = dropout
     self.activation = activation
 
-    self.linear1 = Linear(n_out=d_ff, l2=l2)
-    self.linear2 = Linear(n_out=d_model, l2=l2)
+    self.linear1 = nn.Linear(n_out=d_ff, l2=l2)
+    self.linear2 = nn.Linear(n_out=d_model, l2=l2)
 
   def forward(self, inp: LayerRef) -> LayerRef:
-    return self.linear2(dropout(self.activation(self.linear1(inp)), dropout=self.dropout))
+    return self.linear2(nn.dropout(self.activation(self.linear1(inp)), dropout=self.dropout))
 
 
-class _ConformerConvBlock(Module):
+class _ConformerConvBlock(nn.Module):
   """
   Conformer convolution block
   FF -> GLU -> depthwise conv -> BN -> Swish -> FF
@@ -49,21 +48,21 @@ def __init__(self, d_model: int, kernel_size: Tuple[int], l2: float = 0.0):
     """
     super().__init__()
 
-    self.positionwise_conv1 = Linear(n_out=d_model * 2, l2=l2)
-    self.depthwise_conv = Conv(n_out=d_model, filter_size=kernel_size, groups=d_model, l2=l2, padding='same')
-    self.positionwise_conv2 = Linear(n_out=d_model, l2=l2)
+    self.positionwise_conv1 = nn.Linear(n_out=d_model * 2, l2=l2)
+    self.depthwise_conv = nn.Conv(n_out=d_model, filter_size=kernel_size, groups=d_model, l2=l2, padding='same')
+    self.positionwise_conv2 = nn.Linear(n_out=d_model, l2=l2)
 
   def forward(self, inp: LayerRef) -> LayerRef:
     x_conv1 = self.positionwise_conv1(inp)
-    x_act = glu(x_conv1)
+    x_act = nn.glu(x_conv1)
     x_depthwise_conv = self.depthwise_conv(x_act)
-    x_bn = batch_norm(x_depthwise_conv)
-    x_swish = swish(x_bn)
+    x_bn = nn.batch_norm(x_depthwise_conv)
+    x_swish = nn.swish(x_bn)
     x_conv2 = self.positionwise_conv2(x_swish)
     return x_conv2
 
 
-class _ConformerConvSubsampleLayer(Module):
+class _ConformerConvSubsampleLayer(nn.Module):
   """
   Conv 2D block with optional max-pooling
   """
@@ -85,24 +84,24 @@ def __init__(self, filter_sizes: List[Tuple[int, ...]], pool_sizes: Union[List[T
     self.dropout = dropout
     self.pool_sizes = pool_sizes
 
-    self.conv_layers = ModuleList()
+    self.conv_layers = nn.ModuleList()
     for filter_size, channel_size in zip(filter_sizes, channel_sizes):
       self.conv_layers.append(
-        Conv(l2=l2, activation=act, filter_size=filter_size, n_out=channel_size, padding=padding))
+        nn.Conv(l2=l2, activation=act, filter_size=filter_size, n_out=channel_size, padding=padding))
 
   def forward(self, inp: LayerRef) -> LayerRef:
-    x = split_dims(inp, axis='F', dims=(-1, 1))
+    x = nn.split_dims(inp, axis='F', dims=(-1, 1))
     for i, conv_layer in enumerate(self.conv_layers):
       x = conv_layer(x)
       if self.pool_sizes and i < len(self.pool_sizes):
-        x = pool(x, pool_size=self.pool_sizes[i], padding='same', mode='max')
+        x = nn.pool(x, pool_size=self.pool_sizes[i], padding='same', mode='max')
       if self.dropout:
-        x = dropout(x, dropout=self.dropout)
-    out = merge_dims(x, axes='static')
+        x = nn.dropout(x, dropout=self.dropout)
+    out = nn.merge_dims(x, axes='static')
     return out
 
 
-class ConformerEncoderLayer(Module):
+class ConformerEncoderLayer(nn.Module):
   """
   Represents a conformer block
   """
@@ -135,36 +134,36 @@ def __init__(self, conv_kernel_size: Tuple[int], ff_act, ff_dim: int, dropout: f
 
   def forward(self, inp: LayerRef) -> LayerRef:
     # FFN
-    x_ffn1_ln = layer_norm(inp)
+    x_ffn1_ln = nn.layer_norm(inp)
     x_ffn1 = self.ffn1(x_ffn1_ln)
-    x_ffn1_out = 0.5 * dropout(x_ffn1, dropout=self.dropout) + inp
+    x_ffn1_out = 0.5 * nn.dropout(x_ffn1, dropout=self.dropout) + inp
 
     # MHSA
-    x_mhsa_ln = layer_norm(x_ffn1_out)
+    x_mhsa_ln = nn.layer_norm(x_ffn1_out)
     x_mhsa = self.mhsa_module(x_mhsa_ln)
     x_mhsa_out = x_mhsa + x_ffn1_out
 
     # Conv
-    x_conv_ln = layer_norm(x_mhsa_out)
+    x_conv_ln = nn.layer_norm(x_mhsa_out)
     x_conv = self.conv_module(x_conv_ln)
-    x_conv_out = dropout(x_conv, dropout=self.dropout) + x_mhsa_out
+    x_conv_out = nn.dropout(x_conv, dropout=self.dropout) + x_mhsa_out
 
     # FFN
-    x_ffn2_ln = layer_norm(x_conv_out)
+    x_ffn2_ln = nn.layer_norm(x_conv_out)
     x_ffn2 = self.ffn2(x_ffn2_ln)
-    x_ffn2_out = 0.5 * dropout(x_ffn2, dropout=self.dropout) + x_conv_out
+    x_ffn2_out = 0.5 * nn.dropout(x_ffn2, dropout=self.dropout) + x_conv_out
 
     # last LN layer
-    return layer_norm(x_ffn2_out)
+    return nn.layer_norm(x_ffn2_out)
 
 
-class ConformerEncoder(Module):
+class ConformerEncoder(nn.Module):
   """
   Represents Conformer encoder architecture
   """
 
-  def __init__(self, encoder_layer: Module, num_blocks: int, conv_kernel_size: Tuple[int, ...] = (32,), ff_act=swish,
-               ff_dim: int = 512, dropout: float = 0.1, att_dropout: float = 0.1, enc_key_dim: int = 256,
+  def __init__(self, encoder_layer: nn.Module, num_blocks: int, conv_kernel_size: Tuple[int, ...] = (32,),
+               ff_act=nn.swish, ff_dim: int = 512, dropout: float = 0.1, att_dropout: float = 0.1, enc_key_dim: int = 256,
                att_n_heads: int = 4, l2: float = 0.0):
     """
     :param encoder_layer:
@@ -186,9 +185,9 @@ def __init__(self, encoder_layer: Module, num_blocks: int, conv_kernel_size: Tup
       filter_sizes=[(3, 3), (3, 3)], pool_sizes=[(2, 2), (2, 2)], channel_sizes=[enc_key_dim, enc_key_dim],
       l2=l2, dropout=dropout)
 
-    self.linear = Linear(n_out=enc_key_dim, l2=l2, with_bias=False)
+    self.linear = nn.Linear(n_out=enc_key_dim, l2=l2, with_bias=False)
 
-    self.conformer_blocks = ModuleList([
+    self.conformer_blocks = nn.ModuleList([
       encoder_layer(
         conv_kernel_size=conv_kernel_size, ff_act=ff_act, ff_dim=ff_dim, dropout=dropout,
         att_dropout=att_dropout, enc_key_dim=enc_key_dim, att_n_heads=att_n_heads, l2=l2
@@ -199,7 +198,7 @@ def __init__(self, encoder_layer: Module, num_blocks: int, conv_kernel_size: Tup
   def forward(self, inp: LayerRef) -> LayerRef:
     x_subsample = self.conv_subsample_layer(inp)
     x_linear = self.linear(x_subsample)
-    x = dropout(x_linear, dropout=self.dropout)
+    x = nn.dropout(x_linear, dropout=self.dropout)
     for conformer_block in self.conformer_blocks:
       x = conformer_block(x)
     return x
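
For orientation, building the encoder after this change might look like the sketch below. The import path and the concrete values are assumptions for illustration; only the keyword names and defaults come from the __init__ signatures in the diff. Note that encoder_layer is annotated as nn.Module but is in fact passed as a class, which the encoder calls num_blocks times to fill its nn.ModuleList:

```python
# Hypothetical usage sketch; the package path is assumed, not shown in the commit.
from returnn_common.nn.conformer import ConformerEncoder, ConformerEncoderLayer

encoder = ConformerEncoder(
  encoder_layer=ConformerEncoderLayer,  # a class; instantiated once per block
  num_blocks=12,            # assumed depth, no default in the diff
  conv_kernel_size=(32,),   # remaining values mirror the signature defaults
  ff_dim=512,
  dropout=0.1,
  att_dropout=0.1,
  enc_key_dim=256,
  att_n_heads=4,
  l2=0.0,
)
```

Also note that ff_act now defaults to nn.swish instead of the bare swish, consistent with the rest of the refactor.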
