diff --git a/i6_models/parts/conformer/mhsa.py b/i6_models/parts/conformer/mhsa.py
new file mode 100644
index 00000000..0ef48ff7
--- /dev/null
+++ b/i6_models/parts/conformer/mhsa.py
@@ -0,0 +1,54 @@
+from __future__ import annotations
+from dataclasses import dataclass
+from typing import Optional, Callable
+import torch
+
+from i6_models.config import ModelConfiguration
+
+
+@dataclass
+class ConformerMHSAV1Config(ModelConfiguration):
+    input_dim: int
+    """input dim and total dimension for query/key and value projections, should be divisible by `num_att_heads`"""
+    num_att_heads: int
+    """number of attention heads"""
+    att_weights_dropout: float
+    """attention weights dropout"""
+    dropout: float
+    """multi-headed self attention output dropout"""
+
+    def __post_init__(self) -> None:
+        super().__post_init__()
+        assert self.input_dim % self.num_att_heads == 0, "input_dim must be divisible by num_att_heads"
+
+
+class ConformerMHSAV1(torch.nn.Module):
+    """
+    Conformer multi-headed self-attention module
+    """
+
+    def __init__(self, cfg: ConformerMHSAV1Config):
+
+        super().__init__()
+
+        self.layernorm = torch.nn.LayerNorm(cfg.input_dim)
+        self.mhsa = torch.nn.MultiheadAttention(
+            cfg.input_dim, cfg.num_att_heads, dropout=cfg.att_weights_dropout, batch_first=True
+        )
+        self.dropout = cfg.dropout
+
+    def forward(self, input_tensor: torch.Tensor, key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+        """
+        Apply layer norm and multi-head self attention and dropout
+        :param Optional[torch.Tensor] key_padding_mask: could be a binary or float mask of shape (B, T)
+            which will be applied/added to dot product, used to mask padded key positions out
+        """
+
+        output_tensor = self.layernorm(input_tensor)  # [B,T,F]
+
+        output_tensor, _ = self.mhsa(
+            output_tensor, output_tensor, output_tensor, key_padding_mask=key_padding_mask, need_weights=False
+        )  # [B,T,F]
+        output_tensor = torch.nn.functional.dropout(output_tensor, p=self.dropout, training=self.training)  # [B,T,F]
+
+        return output_tensor
diff --git a/tests/test_conformer.py b/tests/test_conformer.py
index 444c5555..81349f78 100644
--- a/tests/test_conformer.py
+++ b/tests/test_conformer.py
@@ -1,3 +1,4 @@
+from __future__ import annotations
 from itertools import product
 
 import torch
@@ -8,24 +9,10 @@
     ConformerPositionwiseFeedForwardV1,
     ConformerPositionwiseFeedForwardV1Config,
 )
+from i6_models.parts.conformer.mhsa import ConformerMHSAV1Config, ConformerMHSAV1
 from i6_models.parts.conformer.norm import LayerNormNC
 
 
-def test_ConformerPositionwiseFeedForwardV1():
-    def get_output_shape(input_shape, input_dim, hidden_dim, dropout, activation):
-        x = torch.randn(input_shape)
-        cfg = ConformerPositionwiseFeedForwardV1Config(input_dim, hidden_dim, dropout, activation)
-        conf_ffn_part = ConformerPositionwiseFeedForwardV1(cfg)
-        y = conf_ffn_part(x)
-        return y.shape
-
-    for input_dim, hidden_dim, dropout, activation in product(
-        [10, 20], [100, 200], [0.1, 0.3], [nn.functional.silu, nn.functional.relu]
-    ):
-        input_shape = (10, 100, input_dim)
-        assert get_output_shape(input_shape, input_dim, hidden_dim, dropout, activation) == input_shape
-
-
 def test_conformer_convolution_output_shape():
     def get_output_shape(batch, time, features, norm=None, kernel_size=31, dropout=0.1, activation=nn.functional.silu):
         x = torch.randn(batch, time, features)
@@ -48,6 +35,40 @@ def get_output_shape(batch, time, features, norm=None, kernel_size=31, dropout=0
     assert get_output_shape(10, 10, 20, kernel_size=32) == (10, 10, 20)  # even kernel size
 
 
+def test_ConformerPositionwiseFeedForwardV1():
+    def get_output_shape(input_shape, input_dim, hidden_dim, dropout, activation):
+        x = torch.randn(input_shape)
+        cfg = ConformerPositionwiseFeedForwardV1Config(input_dim, hidden_dim, dropout, activation)
+        conf_ffn_part = ConformerPositionwiseFeedForwardV1(cfg)
+        y = conf_ffn_part(x)
+        return y.shape
+
+    for input_dim, hidden_dim, dropout, activation in product(
+        [10, 20], [100, 200], [0.1, 0.3], [nn.functional.silu, nn.functional.relu]
+    ):
+        input_shape = (10, 100, input_dim)
+        assert get_output_shape(input_shape, input_dim, hidden_dim, dropout, activation) == input_shape
+
+
+def test_ConformerMHSAV1():
+    def get_output_shape(input_shape, cfg, **kwargs):
+
+        input = torch.randn(input_shape)
+        output = ConformerMHSAV1(cfg)(input, **kwargs)
+
+        return list(output.shape)
+
+    # without key padding mask
+    input_shape = [3, 10, 20]  # B,T,F
+    cfg = ConformerMHSAV1Config(20, 4, 0.1, 0.1)
+    assert get_output_shape(input_shape, cfg) == [3, 10, 20]
+
+    # with key padding mask
+    input_shape = [4, 15, 32]  # B,T,F
+    cfg = ConformerMHSAV1Config(32, 8, 0.2, 0.3)
+    assert get_output_shape(input_shape, cfg, key_padding_mask=torch.randint(0, 2, input_shape[:2]) > 0) == [4, 15, 32]
+
+
 def test_layer_norm_nc():
     torch.manual_seed(42)
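
Usage sketch (not part of the diff above): a minimal example of how the new ConformerMHSAV1 module could be driven, including one way to derive a boolean key_padding_mask from per-sequence lengths. The sequence-length values and the mask construction are illustrative assumptions; only ConformerMHSAV1Config/ConformerMHSAV1 and the torch.nn.MultiheadAttention mask convention (boolean mask, True = padded key position to ignore) come from the code above and from PyTorch itself.

import torch

from i6_models.parts.conformer.mhsa import ConformerMHSAV1, ConformerMHSAV1Config

# input_dim must be divisible by num_att_heads (checked in __post_init__)
cfg = ConformerMHSAV1Config(input_dim=32, num_att_heads=8, att_weights_dropout=0.1, dropout=0.1)
mhsa = ConformerMHSAV1(cfg)

features = torch.randn(4, 15, 32)  # [B, T, F]
seq_lens = torch.tensor([15, 12, 9, 5])  # valid frames per sequence (illustrative values)

# boolean mask of shape [B, T]: True marks padded key positions the attention should ignore
key_padding_mask = torch.arange(features.shape[1])[None, :] >= seq_lens[:, None]

output = mhsa(features, key_padding_mask=key_padding_mask)  # [B, T, F]
assert output.shape == features.shape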