Skip to content

Commit 1f69005

Browse files
committed
fix: config bug
1 parent 739c601 commit 1f69005

File tree

3 files changed: +9 −8 lines changed

src/model/decoder.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@ def __init__(self, is_base: bool = True):
2424

2525
self.masked_mha = MultiHeadAttention(masked_attention=True)
2626
self.mha = MultiHeadAttention(masked_attention=False)
27-
self.ln = LayerNorm(self.config.train_hparams.eps)
27+
self.ln = LayerNorm(self.config.model.train_hparams.eps)
2828
self.ffn = FeedForwardNetwork()
29-
self.residual_dropout = nn.Dropout(p=self.config.model_params.dropout)
29+
self.residual_dropout = nn.Dropout(p=self.config.model.model_params.dropout)
3030

3131
def attention_mask(self, batch_size: int, seq_len: int) -> Tensor:
3232
attention_shape = (batch_size, seq_len, seq_len)
@@ -76,7 +76,7 @@ def __init__(self, langpair: str, is_base: bool = True) -> None:
7676
self.embedding = Embeddings(langpair)
7777
self.config = Config()
7878
self.config.add_model(is_base)
79-
self.num_layers = self.config.model_params.num_decoder_layer
79+
self.num_layers = self.config.model.model_params.num_decoder_layer
8080
self.decoder_layers = get_clones(DecoderLayer(), self.num_layers)
8181

8282
def forward(

src/model/encoder.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,10 @@ def __init__(self, is_base: bool = True):
3636
self.config.add_model(is_base)
3737

3838
self.mha = MultiHeadAttention(masked_attention=False)
39-
self.attention_dropout = nn.Dropout(p=self.config.model_params.dropout)
40-
self.ln = LayerNorm(self.config.train_hparams.eps)
39+
self.attention_dropout = nn.Dropout(p=self.config.model.model_params.dropout)
40+
self.ln = LayerNorm(self.config.model.train_hparams.eps)
4141
self.ffn = FeedForwardNetwork()
42-
self.residual_dropout = nn.Dropout(p=self.config.model_params.dropout)
42+
self.residual_dropout = nn.Dropout(p=self.config.model.model_params.dropout)
4343

4444
def forward(self, source_emb: Tensor, source_mask: Tensor) -> Tuple[Tensor, Tensor]:
4545
source_emb = source_emb + self.mha(
@@ -64,7 +64,7 @@ def __init__(self, langpair: str, is_base: bool = True) -> None:
6464
self.embedding = Embeddings(langpair)
6565
self.config = Config()
6666
self.config.add_model(is_base)
67-
self.num_layers = self.config.model_params.num_encoder_layer
67+
self.num_layers = self.config.model.model_params.num_encoder_layer
6868
self.encoder_layers = get_clones(EncoderLayer(), self.num_layers)
6969

7070
def forward(self, source_tokens: Tensor, source_mask: Tensor) -> NamedTuple:

src/model/modules.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ def __init__(self, langpair: str, is_base: bool = True) -> None:
5252
super().__init__()
5353
# TODO: support transformer-base and transformer-big
5454
configs = Config()
55-
configs.add_model(is_base).add_tokenizer(langpair)
55+
configs.add_model(is_base)
56+
configs.add_tokenizer(langpair)
5657
tokenizer = load_tokenizer(langpair)
5758
padding_idx = tokenizer.token_to_id("<pad>")
5859

0 commit comments

Comments (0)