|
| 1 | +""" |
| 2 | +@author: xiao-data |
| 3 | +""" |
| 4 | +import tensorflow as tf |
| 5 | +import numpy as np |
| 6 | +from tensorflow.examples.tutorials.mnist import input_data |
| 7 | +mnist = input_data.read_data_sets("./data/", one_hot=True) |
| 8 | +batch_size = 128 |
| 9 | +learning_rate = 1e-3 |
| 10 | +display_step = 10 |
| 11 | +test_step = 500 |
| 12 | +num_steps = 50000 |
| 13 | +dropout = 0.5 |
| 14 | +l2_lambda = 1e-5 |
| 15 | + |
| 16 | +X = tf.placeholder(tf.float32, [None, 28*28]) |
| 17 | +Y = tf.placeholder(tf.float32, [None, 10]) |
| 18 | +keep_prob = tf.placeholder(tf.float32) # dropout (keep probability) |
| 19 | + |
| 20 | +def conv2d(x, W, b, strides=1): |
| 21 | + x = tf.nn.conv2d(x, W, strides=[1, strides, strides, 1], padding='VALID') |
| 22 | + x = tf.nn.bias_add(x, b) |
| 23 | +# return tf.nn.relu(x) |
| 24 | + return tf.maximum(0.1*x,x) #leaky relu |
| 25 | + |
| 26 | +def maxpool2d(x, k=2): |
| 27 | + return tf.nn.max_pool(x, ksize=[1, k, k, 1], strides=[1, k, k, 1], padding='VALID') |
| 28 | + |
| 29 | +def fc(x, W, b): |
| 30 | + x = tf.add(tf.matmul(x, W) , b) |
| 31 | + return tf.maximum(0.1*x,x) |
| 32 | +# return tf.nn.relu(x) |
| 33 | +# return tf.nn.tanh(x) |
| 34 | + |
| 35 | +def lenet(X, weights, biases, dropout): |
| 36 | + X = tf.reshape(X, [-1, 28, 28, 1]) |
| 37 | + X = tf.pad(X, [[0,0],[2,2],[2,2], [0,0]]) |
| 38 | + conv1 = conv2d(X, weights['conv1'], biases['conv1']) |
| 39 | + pool2 = maxpool2d(conv1) |
| 40 | + conv3 = conv2d(pool2, weights['conv3'], biases['conv3']) |
| 41 | + pool4 = maxpool2d(conv3) |
| 42 | + conv5 = conv2d(pool4, weights['conv5'], biases['conv5']) |
| 43 | + conv5 = tf.contrib.layers.flatten(conv5) |
| 44 | + fc6 = fc(conv5, weights['fc6'],biases['fc6']) |
| 45 | + fc7 = fc(fc6, weights['fc7'],biases['fc7']) |
| 46 | + fc7 = tf.nn.dropout(fc7, dropout) |
| 47 | + return fc7 |
| 48 | + |
| 49 | +weights = { |
| 50 | + 'conv1' : tf.Variable(tf.random_normal([5, 5, 1, 6])), |
| 51 | + 'conv3' : tf.Variable(tf.random_normal([5, 5, 6, 16])), |
| 52 | + 'conv5' : tf.Variable(tf.random_normal([5, 5, 16, 120])), |
| 53 | + 'fc6' : tf.Variable(tf.random_normal([120, 84])), |
| 54 | + 'fc7' : tf.Variable(tf.random_normal([84, 10])) |
| 55 | +} |
| 56 | +biases = { |
| 57 | + 'conv1' : tf.Variable(tf.random_normal([6])), |
| 58 | + 'conv3' : tf.Variable(tf.random_normal([16])), |
| 59 | + 'conv5' : tf.Variable(tf.random_normal([120])), |
| 60 | + 'fc6' : tf.Variable(tf.random_normal([84])), |
| 61 | + 'fc7' : tf.Variable(tf.random_normal([10])) |
| 62 | +} |
| 63 | + |
| 64 | +logits = lenet(X, weights, biases, keep_prob) |
| 65 | +prediction = tf.nn.softmax(logits) |
| 66 | + |
| 67 | +loss_op = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits( |
| 68 | + logits=logits, labels=Y)) |
| 69 | +l2_loss = tf.contrib.layers.apply_regularization(regularizer=tf.contrib.layers.l2_regularizer(l2_lambda), weights_list=tf.trainable_variables()) |
| 70 | +final_loss = loss_op + l2_loss |
| 71 | +optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate) |
| 72 | +train_op = optimizer.minimize(final_loss) |
| 73 | + |
| 74 | +correct_pred = tf.equal(tf.argmax(prediction, 1), tf.argmax(Y, 1)) |
| 75 | +accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32)) |
| 76 | +init = tf.global_variables_initializer() |
| 77 | + |
| 78 | +with tf.Session() as sess: |
| 79 | + sess.run(init) |
| 80 | + X_test = mnist.test.images[:10000] |
| 81 | + Y_test = mnist.test.labels[:10000] |
| 82 | + for step in range(1, num_steps+1): |
| 83 | + batch_x, batch_y = mnist.train.next_batch(batch_size) |
| 84 | + sess.run(train_op, feed_dict={X: batch_x, Y: batch_y, keep_prob: dropout}) |
| 85 | + if step % display_step == 0 or step == 1: |
| 86 | + pre,loss, acc = sess.run([prediction,loss_op, accuracy], feed_dict={X: batch_x, Y: batch_y, keep_prob: 1.0}) |
| 87 | + print("Step " + str(step) + \ |
| 88 | + ", Minibatch Loss= " + "{:.4f}".format(loss) + \ |
| 89 | + ", Training Accuracy= " + "{:.3f}".format(acc)) |
| 90 | + if step % test_step == 0 and step > 10000: |
| 91 | + print("Test Step "+str(step)+": Accuracy:", \ |
| 92 | + sess.run(accuracy, feed_dict={X: X_test, Y: Y_test,keep_prob: 1.0})) |
| 93 | + print("Optimization Finished!") |
0 commit comments