# Reconstructed from a flattened patch: new file i6_models/parts/conformer/feedforward.py.
# (The same patch line also opens the requirements.txt hunk, which re-adds `typeguard`
# with a trailing newline.)
from __future__ import annotations
from dataclasses import dataclass
from typing import Callable

import torch
from torch import nn

from i6_models.config import ModelConfiguration


@dataclass
class ConformerPositionwiseFeedForwardV1Config(ModelConfiguration):
    input_dim: int
    """feature dimension F of the module input and output"""
    hidden_dim: int
    """width of the inner projection (normally set to 4*input_dim as suggested by the paper)"""
    dropout: float
    """dropout probability, used after the activation and after the output projection"""
    activation: Callable[[torch.Tensor], torch.Tensor] = nn.functional.silu
    """activation function applied to the output of the first linear layer"""


class ConformerPositionwiseFeedForwardV1(nn.Module):
    """
    Conformer position-wise feedforward module:
    layer norm -> linear (F -> hidden_dim) -> activation -> dropout -> linear (hidden_dim -> F) -> dropout

    No residual connection is added inside this module.
    """

    def __init__(self, cfg: ConformerPositionwiseFeedForwardV1Config):
        """
        :param cfg: dimensions, dropout probability and activation of the module
        """
        super().__init__()

        in_dim, ff_dim = cfg.input_dim, cfg.hidden_dim

        # attribute names and creation order are kept stable (parameter names / init order)
        self.layer_norm = nn.LayerNorm(in_dim)
        self.linear_ff = nn.Linear(in_dim, ff_dim, bias=True)
        self.activation = cfg.activation
        self.linear_out = nn.Linear(ff_dim, in_dim, bias=True)
        self.dropout = cfg.dropout  # plain float, applied via nn.functional.dropout in forward

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """
        :param tensor: shape [B,T,F], F=input_dim
        :return: shape [B,T,F], F=input_dim
        """
        normed = self.layer_norm(tensor)  # [B,T,F]
        hidden = self.activation(self.linear_ff(normed))  # [B,T,hidden_dim]
        hidden = nn.functional.dropout(hidden, p=self.dropout, training=self.training)  # [B,T,hidden_dim]
        projected = self.linear_out(hidden)  # [B,T,F]
        # dropout is only active while the module is in training mode
        return nn.functional.dropout(projected, p=self.dropout, training=self.training)  # [B,T,F]
# Reconstructed from a flattened patch: this line starts with the `+torch` addition to
# requirements.txt, followed by the new file tests/test_conformer.py shown below.
from itertools import product

import torch
from torch import nn

from i6_models.parts.conformer.feedforward import (
    ConformerPositionwiseFeedForwardV1,
    ConformerPositionwiseFeedForwardV1Config,
)


def test_ConformerPositionwiseFeedForwardV1():
    """
    The feedforward module must return a tensor with the same shape as its input
    for every tested combination of input_dim, hidden_dim, dropout and activation.
    """
    batch_size, num_frames = 10, 100
    param_grid = product([10, 20], [100, 200], [0.1, 0.3], [nn.functional.silu, nn.functional.relu])

    for input_dim, hidden_dim, dropout, activation in param_grid:
        input_shape = (batch_size, num_frames, input_dim)

        features = torch.randn(input_shape)
        cfg = ConformerPositionwiseFeedForwardV1Config(input_dim, hidden_dim, dropout, activation)
        module = ConformerPositionwiseFeedForwardV1(cfg)
        output = module(features)

        assert output.shape == input_shape