Feature/encoder #7

Merged · 22 commits · Oct 17, 2020

Changes from 1 commit
style: apply isort and black
inmoonlight committed Sep 16, 2020
commit a67a773a49884166dc669ae0f3eda1b294aaf195
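The changes below are purely mechanical: isort regroups and alphabetizes imports, while black normalizes string quotes to double quotes and wraps statements that exceed its default 88-character line length. Assuming both tools are installed with default settings (the exact invocation is not recorded in this PR), a pass like this can presumably be reproduced with:

python -m isort src/
python -m black src/
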
11 changes: 8 additions & 3 deletions src/dataset.py
@@ -4,7 +4,7 @@
 from pytorch_lightning import LightningDataModule
 from torch.utils.data import DataLoader, Dataset
 
-from .utils import get_configs, read_lines, load_tokenizer
+from .utils import get_configs, load_tokenizer, read_lines
 
 
 class WMT14Dataset(Dataset):

@@ -28,7 +28,9 @@ def __init__(
     def __len__(self) -> int:
         return len(self.source_lines)
 
-    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    def __getitem__(
+        self, index: int
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         source_encoded, source_mask, target_encoded, target_mask, = self.process(
             self.source_lines[index], self.target_lines[index]
         )

@@ -102,7 +104,10 @@ def process(
             torch.tensor(source_encoded),
             torch.tensor(source_mask),
         )
-        target_encoded, target_mask = (torch.tensor(target_encoded), torch.tensor(target_mask))
+        target_encoded, target_mask = (
+            torch.tensor(target_encoded),
+            torch.tensor(target_mask),
+        )
         return source_encoded, source_mask, target_encoded, target_mask

57 changes: 41 additions & 16 deletions src/modules.py
@@ -13,20 +13,31 @@ class PositionalEncoding(nn.Module):
         max_len: maximum length of the tokens
         embedding_dim: embedding dimension of the given token
     """
+
     def __init__(self, max_len: int, embedding_dim: int) -> None:
         super().__init__()
-        config = get_config('model')
+        config = get_config("model")
         self.dropout = nn.Dropout(p=config.pe_dropout)
         positional_encoding = torch.zeros(max_len, embedding_dim)
-        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
-        div_term = torch.exp(torch.arange(0, embedding_dim, 2).float() / embedding_dim * math.log(1e4))
+        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(
+            1
+        )  # (max_len, 1)
+        div_term = torch.exp(
+            torch.arange(0, embedding_dim, 2).float() / embedding_dim * math.log(1e4)
+        )
         positional_encoding[:, 0::2] = torch.sin(position / div_term)
         positional_encoding[:, 1::2] = torch.cos(position / div_term)
-        positional_encoding = positional_encoding.unsqueeze(0).transpose(0, 1)  # (max_len, 1, embedding_dim)
-        self.register_buffer('positional_encoding', positional_encoding)  # TODO: register_buffer?
+        positional_encoding = positional_encoding.unsqueeze(0).transpose(
+            0, 1
+        )  # (max_len, 1, embedding_dim)
+        self.register_buffer(
+            "positional_encoding", positional_encoding
+        )  # TODO: register_buffer?
 
     def forward(self, embeddings: Tensor) -> Tensor:
-        embeddings = embeddings + self.positional_encoding  # (batch_size, max_len, embedding_dim)
+        embeddings = (
+            embeddings + self.positional_encoding
+        )  # (batch_size, max_len, embedding_dim)
         embeddings = self.dropout(embeddings)
         return embeddings
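For reference, the reformatted block above is the standard sinusoidal positional encoding from "Attention Is All You Need": with d = embedding_dim and even index i, div_term = exp(i / d * ln 1e4) = 10000^(i/d), so the two assignments compute

PE(pos, i)     = sin(pos / 10000^(i/d))
PE(pos, i + 1) = cos(pos / 10000^(i/d))

for every position pos up to max_len.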

@@ -37,18 +48,23 @@ class Attention(nn.Module):
     Attributes:
         attention_mask: whether to mask attention or not
     """
+
     def __init__(self, attention_mask: bool = False) -> None:
         super().__init__()
         self.attention_mask = attention_mask
-        self.config = get_config('model')
+        self.config = get_config("model")
         self.dim_q = self.config.dim_q
         self.dim_k = self.config.dim_k
         self.dim_v = self.config.dim_v
         self.dim_model = self.config.dim_model
         if attention_mask:
-            assert (self.dim_k == self.dim_v), "masked self-attention requires key, and value to be of the same size"
+            assert (
+                self.dim_k == self.dim_v
+            ), "masked self-attention requires key, and value to be of the same size"
         else:
-            assert (self.dim_q == self.dim_k == self.dim_v), "self-attention requires query, key, and value to be of the same size"
+            assert (
+                self.dim_q == self.dim_k == self.dim_v
+            ), "self-attention requires query, key, and value to be of the same size"
 
         self.q_project = nn.Linear(self.dim_model, self.dim_q)
         self.k_project = nn.Linear(self.dim_model, self.dim_k)

@@ -61,7 +77,9 @@ def forward(self, embeddings: Tensor, mask: Tensor) -> Tensor:
         q = self.q_project(embeddings)  # (batch_size, max_len, dim_q)
         k = self.k_project(embeddings)  # (batch_size, max_len, dim_k)
         v = self.v_project(embeddings)  # (batch_size, max_len, dim_v)
-        qk = torch.bmm(q, k.transpose(1, 2)) * self.scale  # (batch_size, max_len, max_len)
+        qk = (
+            torch.bmm(q, k.transpose(1, 2)) * self.scale
+        )  # (batch_size, max_len, max_len)
         qk = qk.masked_fill(mask == 0, self.config.train_hparams.eps)
         attention_weight = torch.softmax(qk, dim=-1)
         attention = torch.matmul(attention_weight, v)  # (batch_size, max_len, dim_v)
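For reference, the forward pass above is scaled dot-product attention, softmax(Q K^T * scale) V computed batch-wise; in the usual formulation the scale factor is 1 / sqrt(d_k), though self.scale itself is defined outside this hunk. Masked positions (mask == 0) are filled with train_hparams.eps before the softmax.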
@@ -75,29 +93,34 @@ class MultiHeadAttention(nn.Module):
     Attributes:
         attention_mask: whether to mask attention or not
     """
+
     def __init__(self, attention_mask: bool = False):
         super().__init__()
         self.attention = Attention(attention_mask)
-        config = get_config('model')
+        config = get_config("model")
         self.batch_size = config.train_hparams.batch_size
         self.dim_model = config.dim_model
         self.dim_v = config.dim_v
         self.num_heads = config.num_heads
         assert (self.dim_model // self.num_heads) == self.dim_v
-        assert (self.dim_model % self.num_heads == 0), "embed_dim must be divisible by num_heads"
+        assert (
+            self.dim_model % self.num_heads == 0
+        ), "embed_dim must be divisible by num_heads"
         self.linear = nn.Linear(self.num_heads * self.dim_v, self.dim_model)
 
     def forward(self, embeddings: Tensor, mask: Tensor) -> Tensor:
         heads = [self.attention(embeddings, mask)[0] for h in range(self.num_heads)]
-        multihead = torch.cat(heads, dim=-1)  # (batch_size, max_len, dim_model * num_heads)
+        multihead = torch.cat(
+            heads, dim=-1
+        )  # (batch_size, max_len, dim_model * num_heads)
         multihead = self.linear(multihead)  # (batch_size, max_len, dim_model)
         return multihead
 
 
 class FeedForwardNetwork(nn.Module):
     def __init__(self):
         super().__init__()
-        config = get_config('model')
+        config = get_config("model")
         self.dim_model = config.dim_model
         self.dim_ff = config.dim_ff
         self.linear1 = nn.Linear(self.dim_model, self.dim_ff, bias=True)
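For reference, MultiHeadAttention follows the usual MultiHead(X) = Concat(head_1, ..., head_h) W_O composition, with self.linear acting as W_O over the num_heads * dim_v concatenation; as written, the list comprehension calls the single shared self.attention module once per head.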
@@ -114,7 +137,7 @@ def forward(self, embeddings: Tensor) -> Tensor:
 class LayerNorm(nn.Module):
     def __init__(self, eps: float = 1e-6):
         super().__init__()
-        config = get_config('model')
+        config = get_config("model")
         self.dim_model = config.dim_model
         self.gamma = nn.Parameter(torch.ones(self.dim_model))
         self.beta = nn.Parameter(torch.zeros(self.dim_model))

@@ -123,5 +146,7 @@ def __init__(self, eps: float = 1e-6):
     def forward(self, embeddings: Tensor) -> Tensor:
         mean = torch.mean(embeddings, dim=-1, keepdim=True)  # (batch_size, max_len, 1)
         std = torch.std(embeddings, dim=-1, keepdim=True)  # (batch_size, max_len, 1)
-        ln = self.gamma * (embeddings - mean) / (std + self.eps) + self.beta  # (batch_size, max_len, dim_model)
+        ln = (
+            self.gamma * (embeddings - mean) / (std + self.eps) + self.beta
+        )  # (batch_size, max_len, dim_model)
         return ln
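For reference, the layer norm above computes LN(x) = gamma * (x - mean(x)) / (std(x) + eps) + beta, with mean and standard deviation taken over the last (dim_model) dimension of each position.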
22 changes: 16 additions & 6 deletions src/transformer.py
@@ -1,7 +1,11 @@
 import torch.nn as nn
 
-from .modules import (FeedForwardNetwork, LayerNorm, MultiHeadAttention,
-                      PositionalEncoding)
+from .modules import (
+    FeedForwardNetwork,
+    LayerNorm,
+    MultiHeadAttention,
+    PositionalEncoding,
+)
 from .utils import get_config, get_configs, load_tokenizer

@@ -10,6 +14,7 @@ class Model(nn.Module):
 
     Attributes:
     """
+
     def __init__(self):
         super().__init__()
         # TODO: embeddings
@@ -24,15 +29,18 @@ def forward(self):
 class Embeddings(nn.Module):
     """Input embeddings with positional encoding
     """
+
     def __init__(self, langpair):
         super().__init__()
         # TODO: support transformer-base and transformer-big
-        config = get_configs('model', 'tokenizer', langpair=langpair)
+        config = get_configs("model", "tokenizer", langpair=langpair)
         self.dim_model = config.model.dim_model
         self.vocab_size = config.tokenizer.vocab_size
         tokenizer = load_tokenizer(config.tokenizer)
-        padding_idx = tokenizer.token_to_id('<pad>')
-        self.embedding_matrix = nn.Embedding(self.vocab_size, self.dim_model, padding_idx=padding_idx)
+        padding_idx = tokenizer.token_to_id("<pad>")
+        self.embedding_matrix = nn.Embedding(
+            self.vocab_size, self.dim_model, padding_idx=padding_idx
+        )
         self.scale = self.dim_model ** 0.5
         self.max_len = config.model.max_len
         self.positional_encoding = PositionalEncoding(self.max_len, self.dim_model)
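A side note on the surrounding context lines: self.scale = self.dim_model ** 0.5 corresponds to the sqrt(d_model) factor the original Transformer multiplies into its token embeddings; where it is applied lies outside this hunk.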
@@ -47,9 +55,10 @@ def forward(self, x) -> nn.Embedding:  # TODO: type of x
 class Encoder(nn.Module):
     """Base class for transformer encoders
     """
+
     def __init__(self):
         super().__init__()
-        self.config = get_config('model')
+        self.config = get_config("model")
         self.num_layers = self.config.num_encoder_layer
         self.embeddings = Embeddings()
         self.mha = MultiHeadAttention(attention_mask=False)
@@ -79,6 +88,7 @@ class Decoder(nn.Module):
     Attributes:
         target_embeddings:
     """
+
     def __init__(self, target_embeddings: nn.Embedding):
         super().__init__()
         self.target_embeddings = target_embeddings
6 changes: 4 additions & 2 deletions src/utils.py
@@ -1,8 +1,8 @@
 from pathlib import Path
 from typing import List, Optional
 
-from tokenizers import SentencePieceBPETokenizer
 from omegaconf import DictConfig, OmegaConf
+from tokenizers import SentencePieceBPETokenizer
 
 
 def read_lines(filepath: str) -> List[str]:
@@ -77,7 +77,9 @@ def get_config(arg: str, langpair: Optional[str] = None) -> DictConfig:
     )
 
     if arg == "model":
-        config_path = config_dir / "transformer-base.yaml"  # TODO: support transformer-big.yaml
+        config_path = (
+            config_dir / "transformer-base.yaml"
+        )  # TODO: support transformer-big.yaml
     else:
         langpair = normalize_langpair(langpair)
         config_path = list(config_dir.glob(f"*{langpair}*"))[0]
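A minimal usage sketch of the config helpers touched in this diff, assuming the package is importable as src and that the expected YAML files (transformer-base.yaml plus a tokenizer config for the chosen language pair) exist; the langpair value "en-de" is only for illustration:

from src.utils import get_config, get_configs

model_config = get_config("model")  # resolves config_dir / "transformer-base.yaml"
print(model_config.dim_model, model_config.num_heads)

config = get_configs("model", "tokenizer", langpair="en-de")
print(config.model.max_len, config.tokenizer.vocab_size)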