步骤
1.导入MNIST数据集
2.分析MNIST样本特点定义变量
3.构建模型
4.训练模型并输出中间状态参数
5.测试模型
6.保存模型
7.读取模型
导入MNIST数据集
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/",one_hot=True)
print('输入数据:',mnist.train.images)
print("输入数据打印shape:",mnist.train.images.shape)
import pylab
im = mnist.train.images[1]
im = im.reshape(-1,28)
pylab.imshow(im)
pylab.show()
结果:
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
输入数据: [[0. 0. 0. … 0. 0. 0.]
[0. 0. 0. … 0. 0. 0.]
[0. 0. 0. … 0. 0. 0.]
…
[0. 0. 0. … 0. 0. 0.]
[0. 0. 0. … 0. 0. 0.]
[0. 0. 0. … 0. 0. 0.]]
输入数据打印shape: (55000, 784)
分析:
代码中的one_hot = True,表示将样本标签转化为one_hot编码。
MNIST数据集中的图片是28×28Pixel,所以,每一幅图就是1行784(28*28)列的数据。
MNIST数据集包括3部分,一部分是训练集(mnist.train.images),一部分是测试集(mnist.test.images),还有一部分就是验证数据集(mnist.validation.images)。
分析图片的特点,定义变量
由于输入图片是个550000*784的矩阵,所以先创建一个[None,784]的占位符x和一个[None,10]的占位符y,然后使用feed机制将图片和标签输入进去。
tf.reset_default_graph()
#定义占位符
x = tf.placeholder(tf.float32,[None,784])
y = tf.placeholder(tf.float32,[None,10])
构建模型
1.定义学习参数
W = tf.Variable(tf.random_normal([784,10]))
b = tf.Variable(tf.zeros([10]))
在这里赋予tf.Variable不同的初值来创建不同的参数。一般将W设为一个随机值,将b设为0。
2.定义输出节点
pred = tf.nn.softmax(tf.matmul(x,W)+b) #softmax分类
3.定义反向传播结构
cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),reduction_indices=1))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
分析:
1.将生成的pred与样本标签y进行一次交叉熵的运算,然后取平均值。
2.将这个结果作为一次正向传播的误差,通过梯度下降的优化方法找到能够使这个误差最小化的b和W的偏移量。
3.更新b和W,使其调整为合适的参数。
训练模型并输出中间状态参数
training_epochs = 25
batch_size = 100
display_step = 1
#启动session
with tf.Session() as sess:
sess.run(tf.global_variables_initializer()) #Initializing OP
#启动循环开始训练
for epoch in range(training_epochs):
avg_cost = 0
total_batch = int(mnist.train.num_examples/batch_size)
for i in range(total_batch):
batch_xs,batch_ys = mnist.train.next_batch(batch_size)
_,c = sess.run([optimizer,cost],feed_dict={x:batch_xs,y:batch_ys})
avg_cost+=c/total_batch
if (epoch+1) % display_step == 0:
print("Epoch:",'%04d'%(epoch+1),"cost=","{:.9f}".format(avg_cost))
print("Finished!")
结果:
Epoch: 0001 cost= 8.555878963
Epoch: 0002 cost= 4.728819415
Epoch: 0003 cost= 3.253529594
Epoch: 0004 cost= 2.524299589
Epoch: 0005 cost= 2.099848258
Epoch: 0006 cost= 1.826618854
Epoch: 0007 cost= 1.637354129
Epoch: 0008 cost= 1.499148029
Epoch: 0009 cost= 1.393936850
Epoch: 0010 cost= 1.310546904
Epoch: 0011 cost= 1.243091764
Epoch: 0012 cost= 1.187183216
Epoch: 0013 cost= 1.139807938
Epoch: 0014 cost= 1.099148231
Epoch: 0015 cost= 1.063529595
Epoch: 0016 cost= 1.032225559
Epoch: 0017 cost= 1.004128178
Epoch: 0018 cost= 0.979137269
Epoch: 0019 cost= 0.956363070
Epoch: 0020 cost= 0.935628735
Epoch: 0021 cost= 0.916631048
Epoch: 0022 cost= 0.899071389
Epoch: 0023 cost= 0.882968679
Epoch: 0024 cost= 0.867918332
Epoch: 0025 cost= 0.853862665
Finished!
解释:
training_epochs代表要把整个训练样本集迭代25次;batch_size代表在训练过程中一次取100条数据进行训练;display_step代表每训练一次就把具体的中间状态显示出来。
测试模型
correct_prediction = tf.equal(tf.argmax(pred,1),tf.argmax(y,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
print("Accuracy:",accuracy.eval({x:mnist.test.images,y:mnist.test.labels}))
注意:这个过程仍在session里进行的。
结果:
Accuracy: 0.8296
分析:
测试错误率的算法是:直接判断预测的结果与真实的标签是否相同,如果相同就表明是正确的,然后将正确的个数除以总个数 ,即为正确率。由于是onehot编码,这里使用了tf.argmax函数返回onehot编码中数值为1的那个元素的下标。
保存模型
saver = tf.train.Saver()
model_path = "log/521model.ckpt"
with tf.Session() as sess:
#......训练完后
save_path = saver.save(sess,model_path)
print("Model saved in file:%s" % save_path)
读取模型
with tf.Session() as sess2:
sess2.run(tf.global_variables_initializer())
#恢复模型变量
saver.restore(sess2,model_path)
#测试模型
correct_prediction = tf.equal(tf.argmax(pred,1),tf.argmax(y,1))
#计算验证集准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
print("Accuracy:",accuracy.eval({x:mnist.validation.images,y:mnist.validation.labels}))
output = tf.argmax(pred,1)
batch_xs,batch_ys = mnist.train.next_batch(2)
outputval,predv = sess2.run([output,pred],feed_dict={x:batch_xs,y:batch_ys})
print(outputval,predv,batch_ys)
im = batch_xs[0]
im = im.reshape(-1,28)
pylab.imshow(im)
pylab.show()
im = batch_xs[1]
im = im.reshape(-1,28)
pylab.imshow(im)
pylab.show()
结果:
Accuracy: 0.827
[7 4] [[1.0536097e-05 7.2111774e-09 2.1802293e-07 1.8825220e-06 4.4055814e-06
6.6896406e-05 2.3730655e-08 9.8775357e-01 1.4292166e-06 1.2161023e-02]
[5.8091226e-10 3.4779413e-10 3.8763491e-07 1.4986345e-13 9.9974340e-01
1.2125688e-06 1.9221659e-06 1.1998784e-04 1.1682522e-04 1.6278111e-05]] [[0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
[0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]]
说明:
第一行是验证集的准确率
第一个数组是输出的预测结果
第二个数组是预测出来的真实输出值
第三个数组是标签值onehot编码表示的7和4。