feat: working on landmark embed
rileydrizzy committed Mar 29, 2024
1 parent 9a0a275 commit 850ef81
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions signa2text/src/models/baseline_transformer.py
@@ -4,8 +4,8 @@
 This module contains the implementation of a Transformer model for sign language tasks.
 Classes:
-    - TokenEmbedding: Create embedding for the target seqeunce
-    - LandmarkEmbedding: Create embedding for the source(frames)seqeunce
+    - TokenEmbedding: Create embedding for the target sequence
+    - LandmarkEmbedding: Create embedding for the source(frames)sequence
     - Encoder: Implements the transformer encoder stack.
     - Decoder: Implements the transformer decoder stack.
     - Transformer: The main transformer model class with methods for training and inference.
@@ -126,7 +126,9 @@ def __init__(
             Dropout rate, by default 0.1
         """
         super().__init__()
-        self.multi_attention = nn.MultiheadAttention(embedding_dim, num_heads)
+        self.multi_attention = nn.MultiheadAttention(
+            embedding_dim, num_heads, batch_first=True
+        )
         self.ffn = nn.Sequential(
             nn.Linear(embedding_dim, feed_forward_dim),
             nn.ReLU(),
@@ -144,14 +146,14 @@ def forward(self, inputs_x):
         multi_attention_out = self.dropout1(multi_attention_out)

         # Residual connection and layer normalization
-        out1 = self.layernorm1(inputs_x + multi_attention_out)
+        outputs_1 = self.layernorm1(inputs_x + multi_attention_out)

         # Feed-forward layer
-        ffn_out = self.ffn(out1)
+        ffn_out = self.ffn(outputs_1)
         ffn_out = self.dropout2(ffn_out)

         # Residual connection and layer normalization
-        x = self.layernorm2(out1 + ffn_out)
+        x = self.layernorm2(outputs_1 + ffn_out)

         return x
