-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathhyperparameters.py
67 lines (42 loc) · 1.48 KB
/
hyperparameters.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# -*- coding: utf-8 -*-
"""
Created on Mon Nov 12 14:23:12 2018
@author: cm
"""
import os
import sys
# Absolute path of the directory that contains this file.
pwd = os.path.dirname(os.path.abspath(__file__))
# Make modules located next to this file importable.  Guarded so that
# re-executing this module (importlib.reload, or loading it both as
# __main__ and as an import) does not stack duplicate sys.path entries.
# NOTE(review): the package import that follows needs the directory that
# *contains* `classifier_multi_label_denses` to be importable; appending
# `pwd` alone may not be what makes it resolve -- confirm.
if pwd not in sys.path:
    sys.path.append(pwd)
from classifier_multi_label_denses.utils import load_vocabulary
class Hyperparamters:
    """Static settings for the multi-label ALBERT classifier.

    Every setting is a class attribute, evaluated once when the class body
    runs (i.e. when this module is imported).  At that moment
    ``load_vocabulary`` is called on ``<pwd>/data/vocabulary_label.txt``, so
    that file must exist for the import to succeed.

    Paths built with ``os.path.join(pwd, ...)`` are anchored to this file's
    directory.  ``logdir`` and ``file_save_model`` are plain relative paths
    and are therefore resolved against the current working directory
    wherever they are used.

    NOTE: the class name's spelling ("Hyperparamters") is kept unchanged
    because it is the public name other modules presumably import.
    """

    # Train parameters
    num_train_epochs = 5
    print_step = 100
    batch_size = 8  # earlier value: 64
    summary_step = 10
    num_saved_per_epoch = 3
    max_to_keep = 100
    # Relative paths: NOT anchored to `pwd` (compare `saved_model_path`).
    logdir = 'logdir/CML_Denses'
    file_save_model = 'model/CML_Denses'
    inference_model = 'CML_Denses'

    # Train/Test data: a directory next to this file plus file names
    # (presumably looked up inside `data_dir` by the data loader).
    data_dir = os.path.join(pwd, 'data')
    train_data = 'train_onehot.csv'
    test_data = 'test_onehot.csv'

    # Load the label vocabulary.  Presumably returns (id -> label,
    # label -> id) mappings -- confirm against utils.load_vocabulary.
    dict_id2label, dict_label2id = load_vocabulary(
        os.path.join(data_dir, 'vocabulary_label.txt'))
    label_vocabulary = list(dict_id2label.values())

    # Optimization parameters
    warmup_proportion = 0.1
    use_tpu = None  # NOTE(review): None, not False; both are falsy
    do_lower_case = True
    learning_rate = 5e-5

    # Sequence and label
    sequence_length = 60
    num_labels = len(dict_id2label)  # number of keys in dict_id2label

    # ALBERT: pretrained files live in the directory <pwd>/<model>.
    model = 'albert_small_zh_google'
    bert_path = os.path.join(pwd, model)
    vocab_file = os.path.join(bert_path, 'vocab_chinese.txt')
    init_checkpoint = os.path.join(bert_path, 'albert_model.ckpt')
    saved_model_path = os.path.join(pwd, 'model')