tensorflow-CNN手写数字识别

CNN主要由3种模块构成:

  • 卷积层
  • 采样层(池化)
  • 全连接层

大致上可以理解为:

  1. 通过第一个卷积层提取最初的特征图
  2. 通过第一个采样层提取最初特征图的关键特征,重构特征图
  3. 通过第二个卷积层对重构的特征图再一次提取出特征图
  4. 通过第二个采样层对再一次提取的特征图再一次提取关键特征,重构特征图
  5. 通过第一个全连接层对最终的特征图进行拉伸分类,即将特征图矩阵群转换成一个向量输入到全连接层得到N个输出。
  6. 由于是10分类任务,我们最终要将N个输出变成10个输出,个人理解为降维,再一次输入到全连接层得到10个输出。

1、引入必要的包并加载MNIST数据集

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist=input_data.read_data_sets("MNIST_data/",one_hot=True)#加载数据,可参考之前的文章介绍

2、声明占位符,声明通用参数函数

x=tf.placeholder(tf.float32,[None,784])
y_=tf.placeholder(tf.float32,[None,10])
x_image=tf.reshape(x,[-1,28,28,1])
def weight_variable(shape):
    initial=tf.truncated_normal(shape,stddev=0.1)
    return tf.Variable(initial)
def bias_variable(shape):
    initial=tf.constant(0.1,shape=shape)
    return tf.Variable(initial)
def conv2d(x,W):
    return tf.nn.conv2d(x,W,strides=[1,1,1,1],padding='SAME')
def max_pool_2x2(x):
    return tf.nn.max_pool(x,ksize=[1,2,2,1],strides=[1,2,2,1],padding="SAME")
  • 声明占位符x,y_以方便后面我们输入数据进行训练
  • x_image为x拉伸之后的结果,x原来是一个大矩阵,每一行代表一个图像,总共又28*28=784个像素点即每一行有784个值,我们把x拉伸成一张一张的图片,[-1,28,28,1]中代表图片为28*28大小的,1代表灰白图片即单颜色通道(若为彩色,则为3),-1可由tensorflow自己计算出来,由我们输入的x的行数来决定,即有多少张图片。
  • weight_variable(即卷积核)和bias_variable为我们最终要训练出来的参数矩阵,由于后面多次需要初始化不同的w和b,故写成函数方便调用。
  • conv2d为进行卷积计算,max_pool_2x2给特征图进行池化,在用卷积核进行特征提取后,输出的特征图会被传递至池化层进行特征选择和信息过滤,padding='SAME'表示给特征图边缘填充0像素。

3、构建卷积层

#第一层卷积
w_conv1=weight_variable([5,5,1,32])
b_conv1=bias_variable([32])
h_conv1=tf.nn.relu(conv2d(x_image,w_conv1)+b_conv1)
h_pool1=max_pool_2x2(h_conv1)
  • 第一层的卷积核设置大小为5*5的矩阵,由于输入数据为单颜色通道的图片(可看成第三维,也称一个通道),故第三个参数为1,我们设置要得到32个特征图,即有32个输出,所以b对应为有32个偏置值的向量。
  • tf.nn.relu为激活函数,得到仅有像素为非负数的矩阵特征图
  • Relu函数图像如下图所示:

  • h_pool1为对特征图进行池化后的特征图(提取关键信息,减少算力)

第二层卷积与第一层卷积大体相似。

#第二层卷积
w_conv2=weight_variable([5,5,32,64])
b_conv2=bias_variable([64])
h_conv2=tf.nn.relu(conv2d(h_pool1,w_conv2)+b_conv2)
h_pool2=max_pool_2x2(h_conv2)
  • 同样使用大小为5*5的卷积核,由于第一层卷积我们得到32个特征图(可看成第三维,也称作3个通道),故第三个参数为32,我们设置要得到64个特征图,即有64个输出,所以b对应为有64个偏置值的向量。

4、设置全连接层对得到的特征图进行分类

#全连接层
w_fc1=weight_variable([7*7*64,1024])
b_fc1=bias_variable([1024])
h_pool2_flat=tf.reshape(h_pool2,[-1,7*7*64])
h_fc1=tf.nn.relu(tf.matmul(h_pool2_flat,w_fc1)+b_fc1)
keep_prob=tf.placeholder(tf.float32)
h_fc1_drop=tf.nn.dropout(h_fc1,keep_prob)
  • 设置全连接层的目的是对第二层卷积得到的64个特征图进行分类,我们知道这64个特征图来自一张图片,我们将这张图片的64个特征图矩阵拉伸成一个向量,即得到7*7*64个像素点输入到全连接层网络得到一个1024维的向量。
  • tf.nn.dropout为以keep_prob的概率杀死神经元节点,防止网络的过拟合。
  • 之后我们将1024个输出值出入到第二层全连接网络,得到10个输出。
#二层全连接网络,把1024维的向量转换成10维,对应10个类别
w_fc2=weight_variable([1024,10])
b_fc2=bias_variable([10])
y_conv=tf.matmul(h_fc1_drop,w_fc2)+b_fc2
  • 得到一个长度为10的向量,这时候若向量为[0,0,1,0,0,0,0,0,0,0]则表示数字2。
  • 到这里我们完成了前向网络的构建,之后我们将进行反向传播调整参数。

5、反向传播

cross_entropy=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_,logits=y_conv))
train_step=tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
correct_prediction=tf.equal(tf.argmax(y_conv,1),tf.argmax(y_,1))
accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
  • 定义交叉熵损失,梯度下降优化参数减小损失,定义准确率。
  • tf.argmax函数的作用是将得到的10维向量取最大值为1,其余为0.4。如:tf.argmax([1,5,3])==[0,1,0]
  • tf.cast为数据转化

6、模型迭代训练

sess=tf.InteractiveSession()#定义图会话
sess.run(tf.global_variables_initializer())#变量初始化
for i in range(20000):#训练20000次
    batch=mnist.train.next_batch(50)#每次取50张照片进行训练
    if i%100==0:#每训练100次(5000张图片)报告准确率
        train_accuracy=accuracy.eval(feed_dict={x:batch[0],y_:batch[1],keep_prob:1.0})
        print("step %d,training accuracy %g" % (i,train_accuracy))
    train_step.run(feed_dict={x:batch[0],y_:batch[1],keep_prob:0.5})
print("test accuracy %g" % accuracy.eval(feed_dict={x:mnist.test.images,y_:mnist.test.labels,keep_prob:1.0}))

至此,我们完成了对于手写数字的模型训练,按以上的模型我们可以得到99%以上的准确率。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值