-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
78 lines (64 loc) · 2.32 KB
/
train.py
File metadata and controls
78 lines (64 loc) · 2.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
"""
@author : seauagain
@date : 2025.11.01
"""
import os, sys, time
import torch
from src.config import default_parser
from src.utils import initial_distributed_logger, setup_current_time, setup_device, time_cost, dict2attr
from src.trainer import Trainer
@time_cost("train_profile.txt")
def main():
    """Entry point: assemble the training configuration and launch training.

    Parses CLI arguments via ``default_parser``, loads a shared zh/en
    tokenizer, fills in model/data hyper-parameters on the config
    namespace, prepares the output directory and logger, then hands
    everything to ``Trainer``.
    """
    # Parse command-line hyper-parameters (lr, batch_size, epochs, seed,
    # model_root/model_name, intervals, ...) into an argparse namespace.
    config = default_parser().parse_args()

    # Local import keeps the heavy `transformers` dependency out of module
    # import time. One tokenizer serves both translation directions, so
    # source and target vocab sizes are identical.
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained("tokenizer/Helsinki-NLP/opus-mt-zh-en")
    tokenizer.add_special_tokens({'bos_token': '<bos>'})
    vocab = tokenizer.get_vocab()

    # Model hyper-parameters (Transformer architecture).
    config.en_vocab_size = len(vocab)
    config.zh_vocab_size = len(vocab)
    config.d_model = 512
    config.d_ff = 1024
    config.max_seq_length = 5000
    config.dropout = 0.1
    config.nums_heads = 8
    config.num_layers = 6

    # Data and optimization settings.
    config.train_data_path = "./dataset/translation2019zh_train.json"
    config.valid_ratio = 0.1
    config.init_lr = 1e-4

    # Special-token indices the model needs for padding and decoding.
    config.src_pad_idx = tokenizer.pad_token_id
    config.trg_pad_idx = tokenizer.pad_token_id
    config.trg_bos_idx = tokenizer.bos_token_id
    config.trg_eos_idx = tokenizer.eos_token_id

    # Ensure the checkpoint/output directory exists. `os` is already
    # imported at module level; the redundant function-local import was removed.
    model_dir = os.path.join(config.model_root, config.model_name)
    os.makedirs(model_dir, exist_ok=True)

    setup_device(config)
    setup_current_time(config)
    logger = initial_distributed_logger(config)  # set up logfile path
    logger.logging_args(config)  # print hyper-parameters in logfile

    trainer = Trainer(config, logger)
    trainer.initialize(config)
    trainer.training_entrance(config)
if __name__ == '__main__':
    main()
## Distributed launch (2 processes per node; the original comment's
## `--nprocnode` is not a valid torchrun flag — the correct spelling is):
## torchrun --nproc_per_node 2 train.py
## Single-process launch:
## python train.py