train.py
# -*- coding: utf-8 -*-
"""
Created on Thu May 30 21:42:07 2019
@author: cm
"""
import os
# Uncomment to force CPU-only training:
# os.environ["CUDA_VISIBLE_DEVICES"] = '-1'
import numpy as np
import tensorflow as tf

from classifier_multi_label_denses.networks import NetworkAlbert
from classifier_multi_label_denses.classifier_utils import get_features
from classifier_multi_label_denses.hyperparameters import Hyperparamters as hp
from classifier_multi_label_denses.utils import select, time_now_string

pwd = os.path.dirname(os.path.abspath(__file__))
MODEL = NetworkAlbert(is_training=True)

# Get data features
input_ids, input_masks, segment_ids, label_ids = get_features()
num_train_samples = len(input_ids)
indexs = np.arange(num_train_samples)
num_batchs = int((num_train_samples - 1) / hp.batch_size) + 1
print('Number of batches:', num_batchs)
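# Illustrative example (assumed numbers, not from the original file): with
# 1,001 training samples and hp.batch_size = 32, the line above gives
# num_batchs = int(1000 / 32) + 1 = 32, i.e. a ceiling division of
# num_train_samples by batch_size.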

# Set up the graph
saver = tf.train.Saver(max_to_keep=hp.max_to_keep)
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Restore a previously saved checkpoint, if one exists
MODEL_SAVE_PATH = os.path.join(pwd, hp.file_save_model)
ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
if ckpt and ckpt.model_checkpoint_path:
    saver.restore(sess, ckpt.model_checkpoint_path)
    print('Restored model!')

with sess.as_default():
    # TensorBoard writer
    writer = tf.summary.FileWriter(hp.logdir, sess.graph)
    for i in range(hp.num_train_epochs):
        np.random.shuffle(indexs)
        for j in range(num_batchs - 1):
            # Indices of the samples in this batch
            i1 = indexs[j * hp.batch_size:min((j + 1) * hp.batch_size, num_train_samples)]

            # Gather the features for this batch
            input_id_ = select(input_ids, i1)
            input_mask_ = select(input_masks, i1)
            segment_id_ = select(segment_ids, i1)
            label_id_ = select(label_ids, i1)

            # Feed dict
            fd = {MODEL.input_ids: input_id_,
                  MODEL.input_masks: input_mask_,
                  MODEL.segment_ids: segment_id_,
                  MODEL.label_ids: label_id_}

            # Run one optimization step
            sess.run(MODEL.optimizer, feed_dict=fd)

            # TensorBoard summaries
            if j % hp.summary_step == 0:
                summary, global_step = sess.run([MODEL.merged, MODEL.global_step], feed_dict=fd)
                writer.add_summary(summary, global_step)

            # Save the model several times per epoch
            if j % (num_batchs // hp.num_saved_per_epoch) == 0:
                if not os.path.exists(MODEL_SAVE_PATH):
                    os.makedirs(MODEL_SAVE_PATH)
                saver.save(sess, os.path.join(MODEL_SAVE_PATH, 'model_%s_%s.ckpt' % (str(i), str(j))))

            # Log the current loss
            if j % hp.print_step == 0:
                loss = sess.run(MODEL.loss, feed_dict=fd)
                print('Time:%s, Epoch:%s, Batch number:%s/%s, Loss:%s'
                      % (time_now_string(), str(i), str(j), str(num_batchs), str(loss)))

    print('Train finished')
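
# Usage sketch (an assumption, not part of the original file): with the package
# laid out as the imports above suggest, training can be launched from the
# project root with
#   python -m classifier_multi_label_denses.train
# and the TensorBoard summaries written to hp.logdir can be viewed with
#   tensorboard --logdir <hp.logdir>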