import numpy as np
import os
import matplotlib.pyplot as plt
import pprint
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# Read the MNIST train/test splits from 'data/', with one-hot encoded labels.
mnist = input_data.read_data_sets('data/', one_hot=True)

# Expose each split as separate image and label arrays.
train_img, train_lbl = mnist.train.images, mnist.train.labels
test_img, test_lbl = mnist.test.images, mnist.test.labels

# Sanity check: show the shape of the training image array.
print(train_img.shape)
# Training hyper-parameters.
lr, epoch, batch_size, snapshot = (
    0.01,  # learning rate for the gradient-descent optimizer
    50,    # number of passes over the training set
    100,   # examples per mini-batch
    5,     # print progress every `snapshot` epochs
)
# --- Model: single-layer softmax regression (784-dim input -> 10 classes) ---
# Inputs: a batch of 784-dimensional vectors and their one-hot targets.
x = tf.placeholder(tf.float32, [None, 784], name='input')
y = tf.placeholder(tf.float32, [None, 10], name='groundtruth')

# Trainable parameters: weight matrix (random normal init) and zero bias.
w = tf.Variable(tf.random_normal([784, 10], stddev=0.5))
b = tf.Variable(tf.zeros([1, 10]))

score = tf.matmul(x, w) + b   # unnormalised logits, shape [batch, 10]
prob = tf.nn.softmax(score)   # class probabilities

# Mean cross-entropy over the batch. Pass logits/labels by keyword:
# TF >= 1.0 raises if these are given positionally, and keywords also
# work with the (logits, labels) signature of older releases.
loss = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(logits=score, labels=y))
optimizer = tf.train.GradientDescentOptimizer(lr).minimize(loss)

# Accuracy: fraction of rows whose predicted class matches the target class.
pred = tf.equal(tf.argmax(prob, 1), tf.argmax(y, 1))
acc = tf.reduce_mean(tf.cast(pred, tf.float32))

# NOTE(review): deprecated in newer TF in favour of
# tf.global_variables_initializer(); kept for older releases.
init = tf.initialize_all_variables()
# --- Training: mini-batch gradient descent, logging per-epoch means ---
loss_cache = []  # mean training loss of each epoch
acc_cache = []   # mean training accuracy of each epoch

# Floor division keeps the batch count an int on both Python 2 and 3.
# It does not change between epochs, so compute it once.
num_batch = mnist.train.num_examples // batch_size

# The `with` block owns the only session and closes it on exit. No separate
# session is created, so none is leaked.
with tf.Session() as sess:
    sess.run(init)
    for ep in range(epoch):
        avg_loss = 0.0
        avg_acc = 0.0
        for _ in range(num_batch):
            batch_x, batch_y = mnist.train.next_batch(batch_size)
            # One step: apply the update and fetch this batch's metrics.
            _, batch_acc, batch_loss = sess.run(
                [optimizer, acc, loss], feed_dict={x: batch_x, y: batch_y})
            avg_loss += batch_loss / num_batch
            # Average accuracy over the epoch (like loss) rather than
            # reporting only the last mini-batch's noisy value.
            avg_acc += batch_acc / num_batch
        loss_cache.append(avg_loss)
        acc_cache.append(avg_acc)
        if ep % snapshot == 0:
            print('Epoch: %d, loss: %.4f, acc: %.4f' % (ep, avg_loss, avg_acc))

    # Evaluate on the full test split while the session is still open.
    test_acc = sess.run(acc, feed_dict={x: test_img, y: test_lbl})
    print('test accuracy: %s' % test_acc)
# --- Plot the per-epoch training curves recorded above ---
plt.figure(1)
plt.plot(range(len(loss_cache)), loss_cache, 'b-', label='loss')
plt.legend(loc='upper right')
plt.show()

plt.figure(2)
# 'o-' = circle markers joined by a solid line. The format string must not
# contain spaces: matplotlib reads ' ' as a style symbol, so 'o -' is rejected
# with a ValueError.
plt.plot(range(len(acc_cache)), acc_cache, 'o-', label='acc')
plt.legend(loc='lower right')
plt.show()
# Sample console output from one run (values vary between runs, e.g. because
# the weights are randomly initialised):
# Epoch: 0, loss: 3.1894, acc: 0.3900
# Epoch: 5, loss: 0.7776, acc: 0.8300
# Epoch: 10, loss: 0.6080, acc: 0.8600
# Epoch: 15, loss: 0.5365, acc: 0.8500
# Epoch: 20, loss: 0.4944, acc: 0.9000
# Epoch: 25, loss: 0.4657, acc: 0.8700
# Epoch: 30, loss: 0.4442, acc: 0.9100
# Epoch: 35, loss: 0.4274, acc: 0.9000
# Epoch: 40, loss: 0.4136, acc: 0.8600
# Epoch: 45, loss: 0.4022, acc: 0.9000
# test accuracy: 0.8925