一、概述
本文的推导参见西瓜书P102~P103,代码参见该网址。主要实现了利用三层神经网络进行手写数字的识别。
二、理论推导
1、参数定义
三层神经网络只有一层隐藏层。参数如下:
x | 输入层输入 |
v | 输入层与隐藏层间的权值 |
α | 隐藏层输入 |
b | 隐藏层输出 |
w |
隐藏层与输出层间的权值 |
输出层输入 | |
输出层输出 |
参数关系如下:
上述等式中fx为激活函数。西瓜书默认激活函数为sigmoid,损失函数为均方根,本文以此为前提进行推导。
2、推导
设损失函数为,则其公式如下:
对求偏导如下,这愚蠢的CSDN不支持多行公式编辑,所以只好手写了:
于是我们就得到了的更新公式:
同样的,对求偏导如下:
于是我们就得到了的更新公式:
三、代码实现
优化方法选择SGD。
1、数据集初始化
MNIST数据集可以在TensorFlow中下载到:
import tensorflow.examples.tutorials.mnist.input_data as input_data
mnist = input_data.read_data_sets("./MNIST_data/", one_hot=True)
mnist对象中就存储着所有的数据,其中,mnist.train.images为50000*784的二维array;储存着训练集的输入,每一行储存着784个像素;mnist.train.labels为50000*10的二维array;储存着训练集的标记,每一行中为1的列对应着该行的label。
2、初始化神经网络类
class NeuralNet:
def __init__(self,InputNum,HiddenNum,OutputNum,LearnRate):
self.InNum=InputNum#输入层节点数
self.HiNum=HiddenNum#隐藏层节点数