Tensorflow入门(8)——多层网络手写识别

本文章内容:

  • 神经网络的搭建:通过定义全连接操作来简化网络搭建
  • 存储模型:在检查点checkpoint保存参数
  • 读取模型:通过读取ckpt文件将已训练模型用于预测

一、输入数据

为了方便所跑数据的可视化,我们会将变量全部写入tensorboard中。因为tensorboard中同一路径下的图像显示是叠加的,因此可以找到存放路径,将曾经的一些无用图删除。

import os

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('MNIST_data',one_hot=True)  # load the MNIST dataset with one-hot labels

# Define placeholders.
tf.reset_default_graph()  # clear the default graph so re-running does not keep adding nodes

# x: flattened 28x28 grayscale images; y: one-hot class labels. None = variable batch size.
x = tf.placeholder(tf.float32,[None,784],name="X")
y = tf.placeholder(tf.float32,[None,10],name="Y")

# Reshape back to NHWC image tensors solely for tf.summary.image visualization below.
image_shaped_input = tf.reshape(x,[-1,28,28,1])
keep_prob = tf.placeholder(tf.float32)   # dropout keep probability, fed at run time

二、建立模型

我们可以定义多层网络,要注意每一层的节点数要对应。

1.建立隐藏层与输出层

# --- Hidden layer 1: 784 -> 500, tanh activation, dropout ---
H1_NN=500
W1 = tf.Variable(tf.truncated_normal([784,H1_NN],stddev=0.1))
b1 = tf.Variable(tf.zeros([H1_NN]))
Y1 = tf.nn.tanh(tf.matmul(x,W1)+b1)
Y1_drop = tf.nn.dropout(Y1,keep_prob)  # keep_prob < 1 only during training

# --- Hidden layer 2: 500 -> 300 ---
H2_NN=300
W2 = tf.Variable(tf.truncated_normal([H1_NN,H2_NN],stddev=0.1))
b2 = tf.Variable(tf.zeros([H2_NN]))
Y2 = tf.nn.tanh(tf.matmul(Y1_drop,W2)+b2)
Y2_drop = tf.nn.dropout(Y2,keep_prob)

# --- Hidden layer 3: 300 -> 500 ---
H3_NN=500
W3 = tf.Variable(tf.truncated_normal([H2_NN,H3_NN],stddev=0.1))
b3 = tf.Variable(tf.zeros([H3_NN]))
Y3 = tf.nn.tanh(tf.matmul(Y2_drop,W3)+b3)
Y3_drop = tf.nn.dropout(Y3,keep_prob)

# --- Output layer: 500 -> 10 ---
WW = tf.Variable(tf.truncated_normal([H3_NN,10],stddev=0.1))
bb = tf.Variable(tf.zeros([10]))

# Logits must stay linear. The original wrapped this in tf.nn.relu, which
# zeroes every negative logit before softmax and distorts the resulting
# class distribution (and its gradients).
forward = tf.matmul(Y3_drop, WW) + bb
pred = tf.nn.softmax(forward)  # probabilities, for prediction/accuracy only

2.参数与函数设置

# Cross-entropy loss computed from the raw logits (`forward`), NOT from
# `pred`: softmax_cross_entropy_with_logits applies softmax internally, so
# feeding the already-softmaxed `pred` (as the original did) applies
# softmax twice, producing a badly scaled loss and weak gradients.
loss_function = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=forward))

# Training hyper-parameters.
train_epochs=50
batch_size=100
total_batch=int(mnist.train.num_examples/batch_size)  # batches per epoch
display_step=1
learning_rate=0.15

# Plain SGD optimizer.
train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss_function)

# Accuracy: fraction of samples whose arg-max prediction matches the label.
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(pred,1))  # argmax returns the index of the max entry
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

# TensorBoard summaries.
tf.summary.image('input', image_shaped_input,10)
tf.summary.histogram('forward',forward)
tf.summary.scalar('loss',loss_function)
tf.summary.scalar('accuracy',accuracy)
merged_summary_op = tf.summary.merge_all()  # merge all summaries into one op

3.训练模型

# NOTE: the original started with a bare `%timeit` line — an IPython magic
# that is a SyntaxError in a plain Python script; removed.
from time import time
startTime=time()  # wall-clock start for total training time

# Initialize all variables.
init = tf.global_variables_initializer()

ckpt_dir = "./ckpt_dir/"
saver = tf.train.Saver()
sess = tf.Session()
sess.run(init)
writer = tf.summary.FileWriter('log/hide_neural_7',sess.graph)

for epoch in range(train_epochs):
    for batch in range(total_batch):
        xs,ys = mnist.train.next_batch(batch_size)
        # Run the train op and fetch the merged summaries in ONE call, so the
        # logged values match the batch actually trained on (the original ran
        # a second, redundant forward pass just for the summaries).
        _, summary_str = sess.run([train_step, merged_summary_op],
                                  feed_dict={x:xs, y:ys, keep_prob:0.8})
        # Use a globally increasing step: the original passed `epoch`, so every
        # batch of an epoch overwrote the same summary point.
        writer.add_summary(summary_str, epoch*total_batch + batch)

    # Evaluate on the validation set with keep_prob=1.0: dropout is a
    # training-time regularizer and must be disabled during evaluation
    # (the original fed 0.8 here, randomly zeroing activations).
    loss,acc=sess.run([loss_function,accuracy],
                      feed_dict={x:mnist.validation.images,
                                 y:mnist.validation.labels,
                                 keep_prob:1.0})

    if(epoch+1)%display_step==0:
        print("Train epoch:","%02d"%(epoch+1),
             "Loss=","{:.9f}".format(loss),
              "Accuracy=","{:.4f}".format(acc))

# Report total run time once, after training (the original also printed a
# near-duplicate message inside the epoch loop).
duration =time()-startTime
print("Train Finished takes:","{:.2f}".format(duration))

saver.save(sess, os.path.join(ckpt_dir, 'mnist_h256_h256_model.ckpt'))  # write checkpoint file
print("Model saved!")

4.评估与预测

import os  # used by os.path.join for checkpoint paths

# Evaluate on the held-out test set. keep_prob must be 1.0 here: dropout
# must be disabled at evaluation time (the original fed 0.8, which
# randomly zeroes activations and understates the true accuracy).
acc_test=sess.run(accuracy,feed_dict={x:mnist.test.images,
                                      y:mnist.test.labels,
                                      keep_prob:1.0})
print("Test Accuracy:",acc_test)

# Predicted class index for every test image (dropout off, as above).
prediction_result=sess.run(tf.argmax(pred,1),
                           feed_dict={x:mnist.test.images,keep_prob:1.0})

import matplotlib.pyplot as plt
import numpy as np
def plot_images_labels_prediction(images,labels,
                                  prediction,idx,num=10):
    """Show up to 25 images starting at index *idx* in a 5x5 grid.

    Each subplot title displays the true label (arg-max of the one-hot
    vector) and, when *prediction* is non-empty, the predicted class.
    """
    figure = plt.gcf()
    figure.set_size_inches(10, 12)
    num = min(num, 25)  # a 5x5 grid holds at most 25 images
    for offset in range(num):
        pos = idx + offset
        axis = plt.subplot(5, 5, offset + 1)
        axis.imshow(np.reshape(images[pos], (28, 28)), cmap='binary')
        caption = "label=" + str(np.argmax(labels[pos]))
        if len(prediction) > 0:
            caption += ",predict=" + str(prediction[pos])
        axis.set_title(caption, fontsize=10)
        axis.set_xticks([])
        axis.set_yticks([])
    plt.show()
    
# Plot 25 test images starting at index 10, with true labels and predictions.
plot_images_labels_prediction(mnist.test.images,
                              mnist.test.labels,
                              prediction_result,10,25)

5.读取模型

print("Starting another session for prediction")
saver = tf.train.Saver()

sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)  # fresh values; restore() below overwrites them from the checkpoint

ckpt = tf.train.get_checkpoint_state(ckpt_dir)  # locate the latest checkpoint in ckpt_dir

if ckpt and ckpt.model_checkpoint_path:
    saver.restore(sess, ckpt.model_checkpoint_path)  # load the trained parameters

# keep_prob must be in the feed dict (and set to 1.0 for evaluation): the
# graph's dropout layers depend on this placeholder, so omitting it — as
# the original did — raises an InvalidArgumentError at run time.
print ("Accuracy:", accuracy.eval(session=sess,
                                  feed_dict={x: mnist.test.images,
                                             y: mnist.test.labels,
                                             keep_prob: 1.0}))

 

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值