-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconfigs.py
28 lines (24 loc) · 884 Bytes
/
configs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
"""This module contains the configuration classes for a base instance of a decoder
transformer model that I will then port from an NLP implementation to a neural implementation"""
from dataclasses import dataclass
from typing import Optional
@dataclass
class ModelConfig:
    """Configuration (hyperparameters) for the model.

    ``d_mlp`` and ``n_heads`` are derived values: when they are not passed
    explicitly they are computed from this instance's ``d_model`` and
    ``d_head`` (``4 * d_model`` and ``d_model // d_head``). Deriving them per
    instance keeps them in sync when ``d_model`` or ``d_head`` is overridden,
    rather than freezing them at the values implied by the class defaults
    (1024 and 4).
    """

    d_model: int = 256
    d_head: int = 64
    # None means "derive from d_model"; resolved to an int in __post_init__.
    d_mlp: Optional[int] = None
    n_blocks: int = 2  # 2 layers per block (Attention + MLP)
    # None means "derive from d_model and d_head"; resolved in __post_init__.
    n_heads: Optional[int] = None
    n_ctx: int = 256
    init_range: float = 0.02  # std for initializing weights
    dropout: float = 0.1

    def __post_init__(self) -> None:
        """Fill in derived sizes that the caller did not set explicitly."""
        if self.d_mlp is None:
            self.d_mlp = self.d_model * 4
        if self.n_heads is None:
            self.n_heads = self.d_model // self.d_head
@dataclass
class TransformerTrainingArgs:
    """Training hyperparameters plus wandb run settings.

    Every attribute is an annotated dataclass field, so each one can be
    overridden through the constructor by keyword and takes part in the
    generated ``__repr__`` and ``__eq__``.
    """

    # The wandb fields are kept first so that positional construction
    # (project, name) stays valid; pass the hyperparameters by keyword.
    wandb_project: Optional[str] = "base_transformer"
    wandb_name: Optional[str] = None
    batch_size: int = 16
    epochs: int = 10
    max_steps_per_epoch: int = 200
    lr: float = 1e-3
    weight_decay: float = 1e-2