四、用简单神经网络识别手写数字(内含代码详解及订正)

本博客主要内容为图书《神经网络与深度学习》和National Taiwan University (NTU)林轩田老师的《Machine Learning》的学习笔记,因此在全文中对它们多次引用。初出茅庐,学艺不精,有不足之处还望大家不吝赐教。

1. 前期准备

1.1 数据集

  MNIST数据集是基于NIST(美国国家标准与技术研究院)收集的两个数据集合。为了构建MNIST,NIST数据集合被Yann LeCun,Corinna Cortes和Christopher J. C. Burges拆分放入一个更方便的格式,本文所使用的数据集是在一种更容易在Python中加载和操纵MNIST数据的形式。从蒙特利尔大学的LISA机器学习实验室获得了这个特殊格式的数据。这个数据可以通过克隆这本书的代码仓库获得数据;如果你不使用git,那么你能够在这里下载数据和代码。需要注意的是在这里的代码的版本是基于python2的,如果使用python3的用户直接运行这个代码会出现错误,在本文的后续给出经过经过修改后的基于python3 的代码版本。
  在这个数据集中包含60,000个训练图像和10,000个测试图像,在训练的过程中将测试集testing data)保持不变,将训练集(training data)分为由50000个图像组成的训练集以及由剩下的10000个图像组成的验证集validation set)。验证集对于解决如何去设置神经网络中的超参数hyper-parameter)是十分有用的。因此以后提到「MNIST训练数据」指的不是原始的60,000图像数据集,而是我们的50,000图像数据集。

1.2 软件的配置

  除了MNIST数据之外,我们还需要一个叫做Numpy的用于处理快速线性代数的Python库。考虑到今后会更加深入的研究神经网络,因此本文给出配置Keras的方法,其中几乎已经涵盖了今后会用到的绝大多数的库。在这里的方式是默认这是一台全新的电脑从头配置Keras,已经配置过的用户可以跳过此步。

2. 搭建分类用神经网络

  为了识别数字,我们将会使用一个如图1的三层神经网络


图1. 用于手写数字分类的神经网络

  这个网络的输入层是用于训练数据的 28×28 的手写数字位图,因此我们的输入层包含了 28×28=784 个神经元。为了方便起见,在图1中没有完全画出 784 个输入神经元。输入的像素点是其灰度值,0.0 代表白色,1.0 代表黑色,中间值表示不同程度的灰度值。
  在网络的第二层,即隐层中设置了 n 个神经元, n 是一个需要通过实验确定的值,在这里取 n 的值为15。
  网络的输出层包含了10个神经元,把输出层神经元依次标记为 0 到 9,找到拥有最高的激活值的神经元,将它的标记作为神将网络的结果进行输出。例如 6 号神经元有最高值,那么我们的神经网络预测输入数字是 6,对其它的神经元也如此。
  为什么用 10 个输出神经元,如果采用二进制编码的形式只需要4个神经元即可。最终的判断是基于经验主义的:我们可以实验两种不同的网络设计,结果证明对于这个特定的问题而言,10 个输出神经元的神经网络比 4 个的识别效果更好。理解这一现象的原因可以参考第一篇博客中第二小节“迈向深度学习”的内容,神经网络是通过一部分一部分学习的,因此把数字的最高有效位和数字的形状联系起来并不是一个简单的问题。很难想象出有什么恰当的历史原因,一个数字的形状要素会和一个数字的最高有效位有什么紧密联系。

3. 手写数字分类的主要代码

  这里是用python搭建神经网络识别手写数字的全部代码,实际上程序只包含74行非空行、非注释代码。在这里面主体是一个叫做Network的类,在这个类里面定义了正向传播、反向传播、随机梯度下降等方法,其次是一些辅助函数,即sigmoid函数与sigmoid导数计算函数,具体方法如下所示

def init(self, sizes):初始化
def feedforward(self, a ): 前馈
def SGD(self, training_data, epochs, mini_batch_size, eta,test_data=None) : 随机梯度下降函数
def update_mini_batch(self, mini_batch, eta) : 小块训练集更新参数
def backprop(self, x, y) : 反向传播
def cost_derivative(self, output_activations, y) : 损失函数在输出层的导数(微商)
def evaluate(self, test_data) : 返回神经网络分类正确的总个数
def sigmoid(z) : sigmoid函数值
def sigmoid_prime(z) : sigmoid导数值

在上述这些方法中存在调用关系,将他们的调用关系绘制出来如图2所示


图2. 神经网络中函数相互调用关系

  因此我们从左至右一次对这就个方法进行详尽的分析

3.1 辅助函数

#### Miscellaneous functions
def sigmoid(z):
    """The sigmoid function."""
    return 1.0/(1.0+np.exp(-z))

def sigmoid_prime(z):
    """Derivative of the sigmoid function."""
    return sigmoid(z)*(1-sigmoid(z))

  这两个函数是计算sigmoid函数及导数值的方法,十分容易理解,因此不做分析

3.2 cost_derivative()

def cost_derivative(self, output_activations, y):
     """Return the vector of pa
  • 8
    点赞
  • 29
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值