![0a9d511a7e369cab6753145f8409797d.png](https://i-blog.csdnimg.cn/blog_migrate/3b9bd4c953bc5138beffb4628c7858b8.jpeg)
好久不更新,国庆节庆祝一下~
提到全连接神经网络相信大家应该都不会觉得陌生(不陌生你点进来干嘛[捂脸]),本文就全连接神经网络的基本原理进行讲解,并手把手、肩并肩地带您实现这一算法。
完整实现代码请参考本人的p...哦不是...github:
tushushu/imylugithub.com![ebb7147d16101296a44b68cd5deb9f06.png](https://i-blog.csdnimg.cn/blog_migrate/950b2dc995a4442bd042271b63d7ad27.jpeg)
1. 原理篇
我们用人话而不是大段的数学公式来讲讲全连接神经网络是怎么一回事。
1.1 网络结构
灵魂画师用PPT画个粗糙的网络结构图如下:
![f6ecb046439ba47869105fa8da06665f.png](https://i-blog.csdnimg.cn/blog_migrate/0eee82572fb0d198a962c28aa64ae2f3.jpeg)
1.2 Simoid函数
Sigmoid函数的表达式是:
不难得出:
所以,Sigmoid函数的值域是(0, 1),导数为y * (1 - y)
1.3 链式求导
z = f(y)
y = g(x)
dz / dy = f'(y)
dy / dx = g'(x)
dz / dz = dz / dy * dy / dx = f'(y) * g'(x)
1.4 向前传播
将当前节点的所有输入执行当前节点的计算,作为当前节点的输出节点的输入。
1.5 反向传播
将当前节点的输出节点对当前节点的梯度损失,乘以当前节点对输入节点的偏导数,作为当前节点的输入节点的梯度损失。
1.6 拓扑排序
假设我们的神经网络中有k个节点,任意一个节点都有可能有多个输入,需要考虑节点执行的先后顺序,原则就是当前节点的输入节点全部执行之后,才可以执行当前节点。
2. 实现篇
本人用全宇宙最简单的编程语言——Python实现了全连接神经网络,便于学习和使用。简单说明一下实现过程,更详细的注释请参考本人github上的代码。
2.1 创建BaseNode抽象类
将BaseNode作为各种类型Node的父类。包括如下属性:
1. name -- 节点名称
2. value -- 节点数据
3. inbound_nodes -- 输入节点
4. outbound_nodes -- 输出节点
5. gradients -- 对于输入节点的梯度
class
2.2 创建InputNode类
用于存储训练、测试数据。其中indexes属性用来存储每个Batch中的数据下标。
class
2.3 创建LinearNode类
用于执行线性运算。
1. Y = WX + Bias
2. dY / dX = W
3. dY / dW = X
4. dY / dBias = 1
class
2.4 创建MseNode类
用于计算预测值与实际值的差异。
1. MSE = (label - prediction) ^ 2 / n_label
2. dMSE / dLabel = 2 * (label - prediction) / n_label
3. dMSE / dPrediction = -2 * (label - prediction) / n_label
class
2.5 创建SigmoidNode类
用于计算Sigmoid值。
1. Y = 1 / (1 + e^(-X))
2. dY / dX = Y * (1 - Y)
class
2.6 创建WeightNode类
用于存储、更新权重。
class
2.7 创建全连接神经网络类
class
2.8 网络结构
def
2.9 学习率
存储学习率,并赋值给所有权重节点。
@property
2.10 拓扑排序
实现拓扑排序,将节点按照更新顺序排列。
def
2.11 前向传播和反向传播
def
2.12 建立全连接神经网络
def
2.13 训练模型
使用随机梯度下降训练模型。
def
2.14 移除无用节点
模型训练结束后,将mse和label节点移除。
def
2.15 训练模型
def
2.16 预测多个样本
def
3 效果评估
3.1 main函数
使用著名的波士顿房价数据集,按照7:3的比例拆分为训练集和测试集,训练模型,并统计准确度。
@run_time
3.2 效果展示
拟合优度0.803,运行时间6.9秒。
效果还算不错~
![78691cd31e93ec3c80282ae2b830d26f.png](https://i-blog.csdnimg.cn/blog_migrate/e05042853b1445ff4f5f83a5b748dd39.jpeg)
3.3 工具函数
本人自定义了一些工具函数,可以在github上查看
tushushu/imylugithub.com![ebb7147d16101296a44b68cd5deb9f06.png](https://i-blog.csdnimg.cn/blog_migrate/950b2dc995a4442bd042271b63d7ad27.jpeg)
- run_time - 测试函数运行时间
- load_boston_house_prices - 加载波士顿房价数据
- train_test_split - 拆分训练集、测试集
- get_r2 - 计算拟合优度
总结
矩阵乘法
链式求导
拓扑排序
梯度下降