本文章内容:
- 神经网络的搭建:通过定义全连接操作来简化网络搭建
- 存储模型:在检查点checkpoint保存参数
- 读取模型:通过读取ckpt文件将已训练模型用于预测
一、输入数据
为了方便所跑数据的可视化,我们会将变量全部写入TensorBoard中。因为TensorBoard中同一路径下的图像显示是叠加的,因此可以找到存放路径,将曾经的一些无用图删除。
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# Load the MNIST data set with one-hot encoded labels.
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

# Start from an empty default graph so that re-running this section does not
# keep adding nodes to the previous graph.
tf.reset_default_graph()

# Graph inputs: 784-value image vectors and 10-way one-hot labels.
x = tf.placeholder(dtype=tf.float32, shape=[None, 784], name="X")
y = tf.placeholder(dtype=tf.float32, shape=[None, 10], name="Y")

# The same inputs viewed as 28x28 single-channel images.
image_shaped_input = tf.reshape(x, shape=[-1, 28, 28, 1])

# Dropout keep probability, supplied through feed_dict at run time.
keep_prob = tf.placeholder(dtype=tf.float32)
二、建立模型
我们可以定义多层网络,要注意每一层的节点数要对应。
1.建立隐藏层与输出层
# Hidden layers. Each layer is: weights drawn from a truncated normal
# (stddev 0.1), zero biases, tanh activation, then dropout controlled by the
# `keep_prob` placeholder. The output width of one layer must match the input
# width of the next.

# Hidden layer 1: 784 -> 500
H1_NN = 500
W1 = tf.Variable(tf.truncated_normal([784, H1_NN], stddev=0.1))
b1 = tf.Variable(tf.zeros([H1_NN]))
Y1 = tf.nn.tanh(tf.matmul(x, W1) + b1)
Y1_drop = tf.nn.dropout(Y1, keep_prob)

# Hidden layer 2: 500 -> 300
H2_NN = 300
W2 = tf.Variable(tf.truncated_normal([H1_NN, H2_NN], stddev=0.1))
b2 = tf.Variable(tf.zeros([H2_NN]))
Y2 = tf.nn.tanh(tf.matmul(Y1_drop, W2) + b2)
Y2_drop = tf.nn.dropout(Y2, keep_prob)

# Hidden layer 3: 300 -> 500
H3_NN = 500
W3 = tf.Variable(tf.truncated_normal([H2_NN, H3_NN], stddev=0.1))
b3 = tf.Variable(tf.zeros([H3_NN]))
Y3 = tf.nn.tanh(tf.matmul(Y2_drop, W3) + b3)
Y3_drop = tf.nn.dropout(Y3, keep_prob)

# Output layer: 500 -> 10 class scores.
WW = tf.Variable(tf.truncated_normal([H3_NN, 10], stddev=0.1))
bb = tf.Variable(tf.zeros([10]))
# Keep the class scores linear. Wrapping them in ReLU would clip every
# negative score to zero before the softmax, and a clipped score receives no
# gradient, so the corresponding class could stop learning.
forward = tf.matmul(Y3_drop, WW) + bb
# Class probabilities.
pred = tf.nn.softmax(forward)
2.参数与函数设置
# Loss: mean softmax cross-entropy over the batch.
# softmax_cross_entropy_with_logits applies softmax itself, so it has to be
# given the unnormalized scores `forward`. Passing `pred` (already a softmax
# output) would apply softmax twice and produce an incorrect loss.
loss_function = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=forward))

# Training hyper-parameters.
train_epochs = 50
batch_size = 100
total_batch = mnist.train.num_examples // batch_size  # batches per epoch
display_step = 1  # print progress every `display_step` epochs
learning_rate = 0.15

# Optimizer: plain gradient descent on the loss.
train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss_function)

# Accuracy: fraction of samples whose highest-probability class matches the
# label. argmax returns the index of the largest value along axis 1.
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(pred, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# Summaries for TensorBoard: up to 10 input images, the distribution of the
# output scores, and the scalar loss / accuracy.
tf.summary.image('input', image_shaped_input, 10)
tf.summary.histogram('forward', forward)
tf.summary.scalar('loss', loss_function)
tf.summary.scalar('accuracy', accuracy)
# Merge all summaries into a single op.
merged_summary_op = tf.summary.merge_all()
3.训练模型
# `os` is needed for the checkpoint path below; it is imported here because
# the file's own `import os` only appears after this section.
import os
from time import time

startTime = time()

# Op that initializes all variables.
init = tf.global_variables_initializer()

ckpt_dir = "./ckpt_dir/"
# Make sure the checkpoint directory exists before saving into it.
os.makedirs(ckpt_dir, exist_ok=True)
saver = tf.train.Saver()

sess = tf.Session()
sess.run(init)
writer = tf.summary.FileWriter('log/hide_neural_7', sess.graph)

for epoch in range(train_epochs):
    # One pass over the training set, with dropout active (keep 80%).
    for batch in range(total_batch):
        xs, ys = mnist.train.next_batch(batch_size)
        sess.run(train_step, feed_dict={x: xs, y: ys, keep_prob: 0.8})

    # Record one summary per epoch, computed on the last training batch.
    # The step value is the epoch index, so one record per epoch is enough.
    # Dropout is disabled so the logged loss/accuracy are not randomly perturbed.
    summary_str = sess.run(merged_summary_op,
                           feed_dict={x: xs, y: ys, keep_prob: 1.0})
    writer.add_summary(summary_str, epoch)

    # Validation metrics are measured with dropout disabled (keep_prob=1.0).
    loss, acc = sess.run([loss_function, accuracy],
                         feed_dict={x: mnist.validation.images,
                                    y: mnist.validation.labels,
                                    keep_prob: 1.0})
    if (epoch + 1) % display_step == 0:
        print("Train epoch:", "%02d" % (epoch + 1),
              "Loss=", "{:.9f}".format(loss),
              "Accuracy=", "{:.4f}".format(acc))

# Push any buffered summary events to disk.
writer.flush()

# Report the total run time.
duration = time() - startTime
print("Train Finished takes:", "{:.2f}".format(duration))

# Write the checkpoint files.
# NOTE(review): the file name says "h256_h256" but the layers built in this
# file are 500/300/500; the name is kept as-is so existing paths keep working.
saver.save(sess, os.path.join(ckpt_dir, 'mnist_h256_h256_model.ckpt'))
print("Model saved!")
4.评估与预测
import os
# Evaluate on the test set.
# Dropout must be disabled at inference time (keep_prob=1.0); with 0.8 a
# random 20% of the hidden activations would be zeroed on every run, which
# lowers the accuracy and makes the result non-deterministic.
acc_test = sess.run(accuracy,
                    feed_dict={x: mnist.test.images,
                               y: mnist.test.labels,
                               keep_prob: 1.0})
print("Test Accuracy:", acc_test)

# Predict the class index of every test image, again without dropout.
prediction_result = sess.run(tf.argmax(pred, 1),
                             feed_dict={x: mnist.test.images, keep_prob: 1.0})
import matplotlib.pyplot as plt
import numpy as np
def plot_images_labels_prediction(images, labels,
                                  prediction, idx, num=10):
    """Show consecutive images in a 5x5 grid, titled with label and prediction.

    Args:
        images: Sequence of images; each one is reshaped to 28x28 for display.
        labels: Sequence of label vectors; the argmax of each is shown as the
            label.
        prediction: Sequence of predicted values indexed like `images`.
            May be empty or None, in which case only the label is shown.
        idx: Index of the first image to show.
        num: Number of images to show. At most 25 are drawn, and never more
            than remain in `images` starting at `idx`.
    """
    fig = plt.gcf()
    fig.set_size_inches(10, 12)

    # The grid has 25 cells; also stop at the end of `images` instead of
    # raising IndexError when idx + num runs past it.
    num = max(0, min(num, 25, len(images) - idx))
    has_prediction = prediction is not None and len(prediction) > 0

    for i in range(num):
        sample = idx + i
        ax = plt.subplot(5, 5, 1 + i)
        ax.imshow(np.reshape(images[sample], (28, 28)), cmap='binary')

        title = "label=" + str(np.argmax(labels[sample]))
        if has_prediction:
            title += ",predict=" + str(prediction[sample])
        ax.set_title(title, fontsize=10)

        # Hide the axis ticks; they carry no information for an image.
        ax.set_xticks([])
        ax.set_yticks([])
    plt.show()
# Show 25 test images starting at index 10, each with its label and the
# model's prediction.
plot_images_labels_prediction(
    images=mnist.test.images,
    labels=mnist.test.labels,
    prediction=prediction_result,
    idx=10,
    num=25,
)
5.读取模型
print("Starting another session for prediction")
saver = tf.train.Saver()
sess = tf.Session()

# Look up the most recent checkpoint recorded in the checkpoint directory.
ckpt = tf.train.get_checkpoint_state(ckpt_dir)
if ckpt and ckpt.model_checkpoint_path:
    # Load the trained parameters. Restoring assigns every saved variable, so
    # no separate variable initialization is run first.
    saver.restore(sess, ckpt.model_checkpoint_path)
    # keep_prob must be fed because the dropout layers depend on it;
    # 1.0 disables dropout for evaluation.
    print("Accuracy:", accuracy.eval(session=sess,
                                     feed_dict={x: mnist.test.images,
                                                y: mnist.test.labels,
                                                keep_prob: 1.0}))
else:
    # Without a checkpoint there is no trained model to evaluate.
    print("No checkpoint found in", ckpt_dir)