54 changes: 54 additions & 0 deletions i6_models/parts/conformer/mhsa.py
@@ -0,0 +1,54 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
import torch

from i6_models.config import ModelConfiguration


@dataclass
class ConformerMHSAV1Config(ModelConfiguration):
input_dim: int
"""input dim and total dimension for query/key and value projections, should be divisible by `num_att_heads`"""
num_att_heads: int
"""number of attention heads"""
att_weights_dropout: float
"""attention weights dropout"""
dropout: float
"""multi-headed self attention output dropout"""

def __post_init__(self) -> None:
super().__post_init__()
assert self.input_dim % self.num_att_heads == 0, "input_dim must be divisible by num_att_heads"


class ConformerMHSAV1(torch.nn.Module):
"""
Conformer multi-headed self-attention module
"""

    def __init__(self, cfg: ConformerMHSAV1Config):
        super().__init__()

self.layernorm = torch.nn.LayerNorm(cfg.input_dim)
self.mhsa = torch.nn.MultiheadAttention(
cfg.input_dim, cfg.num_att_heads, dropout=cfg.att_weights_dropout, batch_first=True
)
self.dropout = cfg.dropout

def forward(self, input_tensor: torch.Tensor, key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Apply layer norm and multi-head self attention and dropout
:param Optional[torch.Tensor] key_padding_mask: could be a binary or float mask of shape (B, T)
which will be applied/added to dot product, used to mask padded key positions out
"""

output_tensor = self.layernorm(input_tensor) # [B,T,F]

output_tensor, _ = self.mhsa(
output_tensor, output_tensor, output_tensor, key_padding_mask=key_padding_mask, need_weights=False
) # [B,T,F]
output_tensor = torch.nn.functional.dropout(output_tensor, p=self.dropout, training=self.training) # [B,T,F]

return output_tensor
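Not part of the diff, for illustration only: a minimal usage sketch of the new module, with a boolean key padding mask derived from per-sequence lengths (this is just one common way to build the mask; for torch.nn.MultiheadAttention, True marks positions to be ignored).

import torch
from i6_models.parts.conformer.mhsa import ConformerMHSAV1, ConformerMHSAV1Config

cfg = ConformerMHSAV1Config(input_dim=20, num_att_heads=4, att_weights_dropout=0.1, dropout=0.1)
mhsa = ConformerMHSAV1(cfg)

features = torch.randn(3, 10, 20)  # [B, T, F]
seq_lens = torch.tensor([10, 7, 4])  # valid length of each sequence in the batch
# True marks padded key positions that attention should ignore
key_padding_mask = torch.arange(10)[None, :] >= seq_lens[:, None]  # [B, T]

out = mhsa(features, key_padding_mask=key_padding_mask)  # [B, T, F]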
51 changes: 36 additions & 15 deletions tests/test_conformer.py
@@ -1,3 +1,4 @@
from __future__ import annotations
from itertools import product

import torch
@@ -8,24 +9,10 @@
ConformerPositionwiseFeedForwardV1,
ConformerPositionwiseFeedForwardV1Config,
)
from i6_models.parts.conformer.mhsa import ConformerMHSAV1Config, ConformerMHSAV1
from i6_models.parts.conformer.norm import LayerNormNC


def test_ConformerPositionwiseFeedForwardV1():
def get_output_shape(input_shape, input_dim, hidden_dim, dropout, activation):
x = torch.randn(input_shape)
cfg = ConformerPositionwiseFeedForwardV1Config(input_dim, hidden_dim, dropout, activation)
conf_ffn_part = ConformerPositionwiseFeedForwardV1(cfg)
y = conf_ffn_part(x)
return y.shape

for input_dim, hidden_dim, dropout, activation in product(
[10, 20], [100, 200], [0.1, 0.3], [nn.functional.silu, nn.functional.relu]
):
input_shape = (10, 100, input_dim)
assert get_output_shape(input_shape, input_dim, hidden_dim, dropout, activation) == input_shape


def test_conformer_convolution_output_shape():
def get_output_shape(batch, time, features, norm=None, kernel_size=31, dropout=0.1, activation=nn.functional.silu):
x = torch.randn(batch, time, features)
@@ -48,6 +35,40 @@ def get_output_shape(batch, time, features, norm=None, kernel_size=31, dropout=0
assert get_output_shape(10, 10, 20, kernel_size=32) == (10, 10, 20) # even kernel size


def test_ConformerPositionwiseFeedForwardV1():
def get_output_shape(input_shape, input_dim, hidden_dim, dropout, activation):
x = torch.randn(input_shape)
cfg = ConformerPositionwiseFeedForwardV1Config(input_dim, hidden_dim, dropout, activation)
conf_ffn_part = ConformerPositionwiseFeedForwardV1(cfg)
y = conf_ffn_part(x)
return y.shape

for input_dim, hidden_dim, dropout, activation in product(
[10, 20], [100, 200], [0.1, 0.3], [nn.functional.silu, nn.functional.relu]
):
input_shape = (10, 100, input_dim)
assert get_output_shape(input_shape, input_dim, hidden_dim, dropout, activation) == input_shape


def test_ConformerMHSAV1():
    def get_output_shape(input_shape, cfg, **kwargs):
        input_tensor = torch.randn(input_shape)
        output = ConformerMHSAV1(cfg)(input_tensor, **kwargs)
        return list(output.shape)

# without key padding mask
input_shape = [3, 10, 20] # B,T,F
cfg = ConformerMHSAV1Config(20, 4, 0.1, 0.1)
assert get_output_shape(input_shape, cfg) == [3, 10, 20]

# with key padding mask
input_shape = [4, 15, 32] # B,T,F
cfg = ConformerMHSAV1Config(32, 8, 0.2, 0.3)
    key_padding_mask = torch.randint(0, 2, input_shape[:2]) > 0
    assert get_output_shape(input_shape, cfg, key_padding_mask=key_padding_mask) == [4, 15, 32]


def test_layer_norm_nc():
torch.manual_seed(42)
