使用Tensorflow处理Mnist手写数据集
Mnist手写数据集是一个入门级的计算机视觉数据集,何谓入门呢?可以这样说,MNIST 问题就相当于图像处理的 Hello World 程序。下面我将使用Tensorflow搭建CNN卷积神经网络来处理MNIST数据集,来一步步的熟悉Tensorflow和CNN。
MNIST数据集介绍
MNIST数据集是一个手写体数据集,简单说就是一堆这样东西:
MNIST的官网地址是MNIST; 通过阅读官网我们可以知道,这个数据集由四部分组成,分别是
;
当然下载下来的数据集被分成两部分:60000行的训练数据集(mnist.train
)和10000行的测试数据集(mnist.test
)。这样的切分很重要,在机器学习模型设计时必须有一个单独的测试数据集不用于训练而是用来评估这个模型的性能,从而更加容易把设计的模型推广到其他数据集上(泛化)。
我们可以看出这个其实并不是普通的文本文件或是图片文件,而是一个压缩文件,下载并解压出来,里面看到的是二进制文件。
一张图片包含28像素X28像素。我们可以用一个数字数组来表示这张图片:
我们把这个数组展开成一个向量,长度是 28x28 = 784。如何展开这个数组(数字间的顺序)不重要,只要保持各个图片采用相同的方式展开。从这个角度来看,MNIST数据集的图片就是在784维向量空间里面的点。
因此,在MNIST训练数据集中,mnist.train.images
是一个形状为 [60000, 784]
的张量,第一个维度数字用来索引图片,第二个维度数字用来索引每张图片中的像素点。在此张量里的每一个元素,都表示某张图片里的某个像素的强度值,值介于0和1之间。
相对应的MNIST数据集的标签是介于0到9的数字,用来描述给定图片里表示的数字。比如,标签0将表示成([1,0,0,0,0,0,0,0,0,0,0])。因此, mnist.train.labels
是一个 [60000, 10]
的数字矩阵。
输入集
得到数据集可以直接使用下面代码,程序会直接下载代码到MNIST_data文件夹中,或者在MNIST数据集的官网Yann LeCun’s website下载。
import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
使用卷积神经网络处理数据集
卷积神经网络介绍
神经网络由大量的神经元相互连接而成。每个神经元接受线性组合的输入后,最开始只是简单的线性加权,后来给每个神经元加上了非线性的激活函数,从而进行非线性变换后输出。每两个神经元之间的连接代表加权值,称之为权重(weight)。不同的权重和激活函数,则会导致神经网络不同的输出。
给定一个未知数字,让神经网络识别是什么数字。此时的神经网络的输入由一组被输入图像的像素所激活的输入神经元所定义。在通过非线性激活函数进行非线性变换后,神经元被激活然后被传递到其他神经元。重复这一过程,直到最后一个输出神经元被激活。从而识别当前数字是什么字,如下图结构所示:
下面推荐一篇博文,我认为介绍神经网络介绍的很详细:通俗理解卷积神经网络
定义计算图
使用 TensorFlow, 你必须明白 TensorFlow:
- 使用图 (graph) 来表示计算任务.
- 在被称之为
会话 (Session)
的上下文 (context) 中执行图. - 使用 tensor 表示数据.
- 通过
变量 (Variable)
维护状态. - 使用 feed 和 fetch 可以为任意的操作(arbitrary operation) 赋值或者从其中获取数据.
下面我们将定义一个多层卷积网络
权重初始化
为了创建这个模型,我们需要创建大量的权重和偏置项。这个模型中的权重在初始化时应该加入少量的噪声来打破对称性以及避免0梯度。由于我们使用的是ReLU神经元,因此比较好的做法是用一个较小的正数来初始化偏置项,以避免神经元节点输出恒为0的问题(dead neurons)。为了不在建立模型的时候反复做初始化操作,我们定义两个函数用于初始化。
# 定义一个函数,用于初始化所有的权值 W
def weight_variable(shape):
initial = tf.truncated_normal(shape, stddev=0.1)
return tf.Variable(initial)
# 定义一个函数,用于初始化所有的偏置项 b
def bias_variabls(shape):
initial = tf.constant(0.1, shape=shape)
return tf.Variable(initial)
卷积和池化
TensorFlow在卷积和池化上有很强的灵活性。我们怎么处理边界?步长应该设多大?在这个实例里,我们会一直使用vanilla版本。我们的卷积使用1步长(st