train_LeNet.py (forked from kratzert/finetune_alexnet_with_tensorflow)
"""
With this script you can train LeNet.
Specify the configuration settings at the beginning according to your
problem.
This script was written for TensorFlow 1.0
Author: Xavier Faure
contact: xavier.faure42(at)gmail.com
"""
import os
import numpy as np
import tensorflow as tf
from datetime import datetime
from LeNet import LeNet
from datagenerator import ImageDataGenerator
"""
Configuration settings
"""
# Input params
image_width = 28
image_height = 28
image_channel = 3
# Paths to the text files for the training and validation sets
train_file = '../data-notMNIST/train.txt'
val_file = '../data-notMNIST/validation.txt'
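# Each text file is assumed (not verified here) to list one sample per line
# as "<image_path> <label>", the format used by the datagenerator in the
# kratzert/finetune_alexnet_with_tensorflow lineage this repo is forked from,
# e.g.:
#   ../data-notMNIST/images/A/a01.png 0
#   ../data-notMNIST/images/B/b07.png 1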
# Learning params
learning_rate = 0.001
num_epochs = 20
batch_size = 128
# Network params
num_classes = 10
# How often we want to write the tf.summary data to disk
display_step = 1
# Path for tf.summary.FileWriter and to store model checkpoints
filewriter_path = "./output"
checkpoint_path = "./output"
# Create the checkpoint directory if it doesn't exist
if not os.path.isdir(checkpoint_path): os.makedirs(checkpoint_path)
# TF placeholder for graph input and output
x = tf.placeholder(tf.float32, [batch_size, image_width, image_height, image_channel])
y = tf.placeholder(tf.float32, [None, num_classes])
keep_prob = tf.placeholder(tf.float32)
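# Note: keep_prob is never fed or used below, and the model is built as
# LeNet(x, num_classes) without it; the placeholder appears to be unused
# here and could be removed (or wired in, if the LeNet class supports dropout).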
# Initialize model
model = LeNet(x, num_classes)
# Link variable to model output
score = model.fc2
# Op for calculating the loss
with tf.name_scope("cross_ent"):
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=score,
                                                                  labels=y))
# Train op
with tf.name_scope("train"):
    # Create optimizer and apply gradient descent to the trainable variables
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    train_op = optimizer.minimize(loss)
# Add the loss to summary
tf.summary.scalar('cross_entropy', loss)
# Evaluation op: Accuracy of the model
with tf.name_scope("accuracy"):
    correct_pred = tf.equal(tf.argmax(score, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
# Add the accuracy to the summary
tf.summary.scalar('accuracy', accuracy)
# Merge all summaries together
merged_summary = tf.summary.merge_all()
# Initialize the FileWriter
writer = tf.summary.FileWriter(filewriter_path)
# Initialize a saver to store model checkpoints
saver = tf.train.Saver()
# Initialize the data generators separately for the training and validation sets
train_generator = ImageDataGenerator(train_file,
                                     horizontal_flip=False, shuffle=True,
                                     scale_size=(image_width, image_height),
                                     nb_classes=num_classes)
val_generator = ImageDataGenerator(val_file, shuffle=False,
                                   scale_size=(image_width, image_height),
                                   nb_classes=num_classes)
# Get the number of training/validation steps per epoch
# (a partial final batch, if any, is dropped)
train_batches_per_epoch = train_generator.data_size // batch_size
val_batches_per_epoch = val_generator.data_size // batch_size
# Start the TensorFlow session
with tf.Session() as sess:

    # Initialize all variables
    sess.run(tf.global_variables_initializer())

    # Add the model graph to TensorBoard
    writer.add_graph(sess.graph)

    print("{} Start training...".format(datetime.now()))
    print("{} Open TensorBoard with --logdir {}".format(datetime.now(),
                                                        filewriter_path))

    # Loop over the number of epochs
    for epoch in range(num_epochs):

        print("{} Epoch number: {}".format(datetime.now(), epoch + 1))

        step = 1
        while step <= train_batches_per_epoch:

            # Get a batch of images and labels
            batch_xs, batch_ys = train_generator.next_batch(batch_size)

            # And run the training op
            sess.run(train_op, feed_dict={x: batch_xs, y: batch_ys})

            # Generate a summary for the current batch and write it to file
            if step % display_step == 0:
                s = sess.run(merged_summary, feed_dict={x: batch_xs,
                                                        y: batch_ys})
                writer.add_summary(s, epoch * train_batches_per_epoch + step)

            step += 1

        # Validate the model on the entire validation set
        print("{} Start validation".format(datetime.now()))
        test_acc = 0.
        test_count = 0
        for _ in range(val_batches_per_epoch):
            batch_tx, batch_ty = val_generator.next_batch(batch_size)
            acc = sess.run(accuracy, feed_dict={x: batch_tx,
                                                y: batch_ty})
            test_acc += acc
            test_count += 1
        test_acc /= test_count
        print("{} Validation Accuracy = {:.4f}".format(datetime.now(), test_acc))

        # Reset the file pointers of the image data generators
        val_generator.reset_pointer()
        train_generator.reset_pointer()

        # Save a checkpoint of the model
        print("{} Saving checkpoint of model...".format(datetime.now()))
        checkpoint_name = os.path.join(checkpoint_path,
                                       'model_epoch' + str(epoch + 1) + '.ckpt')
        save_path = saver.save(sess, checkpoint_name)
        print("{} Model checkpoint saved at {}".format(datetime.now(),
                                                       checkpoint_name))