-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathconfig.py
More file actions
32 lines (27 loc) · 869 Bytes
/
Copy pathconfig.py
File metadata and controls
32 lines (27 loc) · 869 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# config.py
# reference : https://youtu.be/ISNdQcPhsts?si=F5xPY5JV92VNdKog
# original code : https://github.com/hkproj/pytorch-transformer/blob/main/config.py
from pathlib import Path
def get_config():
    """Build the run configuration for the transformer training script.

    A brand-new dict is constructed on every call, so a caller may mutate
    the result without affecting what later calls return.

    Returns:
        dict: mapping of setting name (str) to its value.
    """
    settings = dict(
        batch_size=8,
        num_epochs=20,
        lr=1e-4,
        seq_len=100,  # the original notes 512 as an alternative length
        d_model=512,
        # Presumably the dataset fields used as source / target text — confirm.
        lang_src="invocation",
        lang_tgt="cmd",
        model_folder="./weights",
        model_basename="tmodel_",
        preload=None,
        # "{0}" is a str.format placeholder, presumably filled with a language key.
        tokenizer_file="tokenizer_{0}.json",
        experiment_name="./tmodel",
        cos_annealing=False,
        beam_search=True,
        beam_width=3,
    )
    return settings
def get_weights_file_path(config, epoch="x"):
    """Return the checkpoint file path for *epoch* inside ``config["model_folder"]``.

    The filename is ``f"{model_basename}{epoch}.pth"``. The default
    ``epoch="x"`` reproduces the single fixed path this function always
    returned (e.g. ``weights/tmodel_x.pth`` for folder ``./weights`` and
    basename ``tmodel_``), so existing callers are unaffected. Passing an
    epoch number or label yields a distinct per-epoch checkpoint name.

    Args:
        config: mapping providing ``"model_folder"`` and ``"model_basename"``.
        epoch: suffix placed before ``.pth``; any value with a ``str`` form.

    Returns:
        str: the joined path. It is relative unless ``model_folder`` is
        absolute; a leading ``./`` is normalized away by ``pathlib``.

    Raises:
        KeyError: if a dict ``config`` lacks either required key.
    """
    model_folder = config["model_folder"]
    model_basename = config["model_basename"]
    # The suffix used to be the hard-coded literal "x"; it is now a parameter
    # with that same default.
    model_filename = f"{model_basename}{epoch}.pth"
    return str(Path(".") / model_folder / model_filename)