05-CNN网络模型之MNIST分类

一、写在前面
笔者前几天已经对之前的需求做了一个简单的CNN模型构建,并成功应用到我们自己的APP上面了,但是在博客的编写过程中发现自己整理的数据集不是很容易用于展示,故这里构建CNN的网络模型使用我们所熟知的MNIST数据集,该数据集在我前面的文章中 有介绍,是一个手写字体的数据集。本篇会介绍简单的CNN网络模型的搭建,以熟悉整个流程。

二、任务目标
1. 构建CNN网络
2. 使用两种模型文件的保存方式保存模型
3. 读取模型的中间结果并训练
4. 正确显示迭代次数

三、任务主要内容
1. 读取数据集

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.python.framework import graph_util
import re
#读取数据集
mnist=input_data.read_data_sets('data/',one_hot=True)
#训练集
train_img=mnist.train.images
print(train_img.shape)
train_labels=mnist.train.labels
print(train_labels.shape)
#测试集
test_img=mnist.test.images
print(test_img.shape)
test_labels=mnist.test.labels
print(test_labels.shape)

这里写图片描述

2. 设置相关参数

#相关参数设置
#学习率
learning_rate=0.01
#训练的迭代次数
train_epochs=100
#每次迭代的样本数量
epoch_size=128

#输入维度为28×28
input_dt=28*28
#输出维度10
output_dt=10

3. 构建模型

x=tf.placeholder(tf.float32,shape=[None,input_dt],name='x')
y=tf.placeholder(tf.float32,shape=[None,output_dt],name='y')
#构建网络
def network(x):
    #1.输入
    with tf.variable_scope('input'):
        nw=tf.reshape(x,[-1,28,28,1])
    #2.卷积层
    with tf.variable_scope('conv2'):
        #在Tensorflow中,步长为1,并且padding为SAME之后,经过卷积之后图像大小是不变的,仅改变通道数
        nw=tf.nn.conv2d(nw,filter=tf.get_variable('w',[5,5,1,50],dtype=tf.float32,initializer=tf.random_normal_initializer(mean=0,stddev=0.01)),strides=[1,1,1,1],padding='SAME')
        nw=tf.nn.bias_add(nw,tf.get_variable('b',[50],dtype=tf.float32,initializer=tf.random_normal_initializer(mean=0,stddev=0.01)))
        #激活 Relu
        nw=tf.nn.relu(nw)
    with tf.variable_scope('pool3'):
        #和conv2一样,需要给定窗口大小和步长
        nw=tf.nn.max_pool(nw,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')

    with tf.variable_scope('conv4'):
        nw=tf.nn.conv2d(nw,filter=tf.get_variable('w',[5,5,50,100],dtype=tf.float32,initializer=tf.random_normal_initializer(mean=0,stddev=0.01)),strides=[1,1,1,1],padding='SAME')
        nw=tf.nn.bias_add(nw,tf.get_variable('b',[100],dtype=tf.float32,initializer=tf.random_normal_initializer(mean=0,stddev=0.01)))
        nw=tf.nn.relu(nw)
    with tf.variable_scope('pool5'):
        nw=tf.nn.max_pool(nw,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')

    with tf.variable_scope('fc6'):
        #28->14->7
        nw=tf.reshape(nw,shape=[-1,7*7*100])
        nw=tf.add(tf.matmul(nw,tf.get_variable('w',[7*7*100,490],dtype=tf.float32,initializer=tf.random_normal_initializer(mean=0,stddev=0.01))),get_variable('b',[490]))
        nw=tf.nn.relu(nw)
    with tf.variable_scope('fc7'):
        nw=tf.add(tf.matmul(nw,tf.get_variable('w',[490,output_dt],dtype=tf.float32,initializer=tf.random_normal_initializer(mean=0,stddev=0.01))),tf.get_variable('b',[output_dt],dtype=tf.float32,initializer=tf.random_normal_initializer(mean=0,stddev=0.01)))
        pre=tf.nn.softmax(nw)
    return pre

4. 损失函数及梯度下降求解

#构建模型的损失函数
loss=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pre,labels=y))

#使用梯度下降的方式求解
train=tf.train.AdamOptimizer(learning_rate=learning_rate,beta1=0.9,beta2=0.999).minimize(loss)

equal=tf.equal(tf.argmax(pre,axis=1),tf.argmax(y,axis=1))
predict_labels=tf.argmax(pre,axis=1,name='output')
#正确率
accuracy=tf.reduce_mean(tf.cast(equal,tf.float32))

5. 模型训练
这里使用了正则匹配来匹配训练过程中的epoch不能restore的情况,如果有更加方便的用法,欢迎给出。

init=tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)

    saver=tf.train.Saver()
    #用于重载保存的中间模型
    result=[0]
    ckpt=tf.train.get_checkpoint_state('./model/')
    if ckpt and ckpt.model_checkpoint_path :
        print("model restoring")
        saver.restore(sess,ckpt.model_checkpoint_path)
        print(ckpt.model_checkpoint_path)
        pattern=re.compile('\d+')
        #result=pattern.match(ckpt.model_checkpoint_path)
        result=pattern.findall(ckpt.model_checkpoint_path)
        print(result[0])

    for epoch in range(train_epochs-int(result[0])):
        avg_loss=0
        epoch_num=int(mnist.train.examples/epoch_size)
        for batch in range(epoch_num):
            next_x,next_y=mnist.train.next_batch(epoch_size)
            sess.run(train,feed_dict={x:next_x,y:next_y})
            avg_loss+=sess.run(loss,feed_dict={x:next_x,y:next_y})
        avg_loss=avg_loss/epoch_num

        #显示信息
        print("epoch:{},loss:{}".format(epoch+1+int(result[0]),avg_loss))
        train_accraucy=sess.run(accuracy,feed_dict={x:train_img,y:train_labels})
        test_accraucy=sess.run(accuracy,feed_dict={x:test_img,y:test_labels})
        print("训练集准确率:%.6f,测试集准确率:%.6f"%(train_accraucy,test_accraucy))

        if (epoch+1)%15==0:
        #.ckpt模型文件保存方式
            saver.save(sess,'./model/model',global_step=epoch+1+int(result[0]))
            print("{} model has saved in ./model/model".format(epoch+1+int(result[0])))
    #.pb模型文件保存方式
    constant_graph=graph_util.convert_variables_to_constants(sess,sess.graph_def,["output"])
    with tf.gfile.FastGFile('pb_path/graph.pb',mode='wb') as file:
        file.write(constant_graph.SerializerToString())

四、分析与展示
1. 使用AdamOptimizer的方式可以大大增加迭代的效率,只需进行少量的迭代即可达到满意的效果
这里写图片描述
2. 按照文章中给出的参数可能会导致程序因为显存不足而卡死或者退出,可更改相应的参数来减少训练过程中的数据量。

五、总结
学习过程中有很多坑必须要自己踩才能在未来的代码生涯中少犯同类型的错误,并且知识量的匮乏还有相关概念的理解的缺失很容易陷入误区,在未来希望能加快学习进度,并且完成大量的知识储备,与君共勉。

  • 4
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值