机器学习11:神经网络中的反向传播算法的简单理解(Backpropagation,BP)

本文通过一个简单的神经网络例子,详细解释了反向传播算法的工作原理,利用链式法则计算误差梯度,并应用梯度下降法更新权重。通过逐步计算,展示了如何从预测值与真实值的误差开始,反向更新每一层的权重,从而逐步优化模型,使其预测结果更加接近真实值。
摘要由CSDN通过智能技术生成

前言

在学习神经网络的时候,一直不理解反向传播算法,看到一大堆公式,总是一头雾水。看了很多大佬的分析和解说,最后通过自己一步步的分析和理解发现,相比直接看公式和方程,似乎一个简单的例子更容易理解这种看似复杂实际上并不复杂的算法,下面主要通过一个简单的例子来表达我对BP(反向传播算法)的理解。

链式法则(乘法法则)

如果已经理解链式法则的可以忽略这一部分。

链式法则可以说是反向传播算法的重中之重,几乎每一步的运算都会用到它。

核心理解如下:
假设有两个函数: y = g ( x ) , z = f ( y ) y=g(x),z=f(y) y=g(x),z=f(y)
很显然: d y d x = g ′ ( x ) , d z d y = f ′ ( y ) \frac{dy}{dx}=g'(x),\frac{dz}{dy}=f'(y) dxdy=g(x),dydz=f(y)
为了求 d z d x \frac{dz}{dx} dxdz,我们就需要用到链式法则:
d z d x = d z d y ⋅ d y d x \frac{dz}{dx}=\frac{dz}{dy}·\frac{dy}{dx} dxdz=dydzdxdy

反向传播算法引例

看下面这个网络:

一共有3层:输入层、隐藏层、输出层。输入层有两个神经元 x 1 , x 2 x_1,x_2 x1,x2,隐藏层有两个神经元 h 1 , h 2 h_1,h_2 h1,h2(这里为了简化问题,方便我们理解算法,我们可以先不用考虑激活函数),输出有一个神经元 y y y.

为了直观的理解这个网络,我们可以为每个参数赋上具体的数值,比如,令 x 1 = 1 , x 2 = 2 x_1=1,x_2=2 x1=1,x2=2,假设各个权重的真实值分别是 w 1 , w 2 , w 3 , w 4 , w 5 , w 6 = 1 , 2 , 3 , 4 , 5 , 6 w_1,w_2,w_3,w_4,w_5,w_6=1,2,3,4,5,6 w1,w2,w3,w4,w5,w6=1,2,3,4,5,6,这样,通过正向运算,我们可以得到 y y y 的目标真实值 t a r g e t = 91 target=91 target=91

因此,在反向传播的过程中,我们只知道 x 1 = 1 , x 2 = 2 x_1=1,x_2=2 x1=1,x2=2以及 t a r g e t = 91 target=91 target=91 w 1 , w 2 , w 3 , w 4 , w 5 , w 6 w_1,w_2,w_3,w_4,w_5,w_6 w1,w2,w3,w4,w5,w6 是未知的,为了使反向传播进行下去,我们需要随机初始化各个权重值,假设我们随机初始化的结果是:
w 1 , w 2 , w 3 , w 4 , w 5 , w 6 = 0.5 , 1.0 , 1.5 , 2.0 , 2.5 , 3.0 w_1,w_2,w_3,w_4,w_5,w_6=0.5,1.0,1.5,2.0,2.5,3.0 w1,w2,w3,w4,w5,w6=0.5,1.0,1.5,2.0,2.5,3.0

下面就是反向传播算法的具体计算过程了:

首先,我们要计算误差,也就是真实值与预测值之间的误差 e = 1 2 ( t a r g e t − y ) 2 e=\frac{1}{2}(target-y)^2 e=21(targety)2。上面提到,真实值 t a r g e t = 91 target=91 target=91,需要我们计算预测值 y y y

h 1 = w 1 x 1 + w 2 x 2 = 2.5 h 2 = w 3 x 1 + w 4 x 2 = 5.5 y = w 5 h 1 + w 6 h 2 = 22.75 \begin{aligned} h_1 &= w_1x_1+w_2x_2=2.5\\ h_2 &= w_3x_1+w_4x_2=5.5\\ y &= w_5h_1+w_6h_2=22.75\\ \end{aligned} h1h2y=w1x1+w2x2=2.5=w3x1+w4x2=5.5=w5h1+w6h2=22.75
e = 1 2 ( t a r g e t − y ) 2 = 2329.03125 e=\frac{1}{2}(target-y)^2=2329.03125 e=21(targety)2=2329.03125

然后,我们就要开始更新各个权重值了,我们先来更新 w 6 w_6 w6 w 5 w_5 w5,为了更新 w 6 w_6 w6,需要计算 ∂ e ∂ w 6 \frac{\partial e}{\partial w_6} w6e,根据链式法则:
∂ e ∂ w 6 = ∂ e ∂ y ⋅ ∂ y ∂ w 6 \frac{\partial e}{\partial w_6}=\frac{\partial e}{\partial y}·\frac{\partial y}{\partial w_6} w6e=yew6y
已知 e = 1 2 ( t a r g e t − y ) 2 e=\frac{1}{2}(target-y)^2 e=21(targety)2,则
∂ e ∂ y = 2 ⋅ 1 2 ⋅ ( t a r g e t − y ) ⋅ ( − 1 ) = y − t a r g e t = 22.75 − 91 = − 68.25 \begin{aligned} \frac{\partial e}{\partial y}&=2·\frac{1}{2}·(target-y)·(-1)\\ &=y-target\\ &=22.75-91\\&=-68.25 \end{aligned} ye=221(targety)(1)=ytarget=22.7591=68.25
y = w 5 h 1 + w 6 h 2 y=w_5h_1+w_6h_2 y=w5h1+w6h2,则
∂ y ∂ w 6 = h 2 = 5.5 \frac{\partial y}{\partial w_6}=h_2=5.5 w6y=h2=5.5
因此,
∂ e ∂ w 6 = ∂ e ∂ y ⋅ ∂ y ∂ w 6 = − 375.375 \frac{\partial e}{\partial w_6}=\frac{\partial e}{\partial y}·\frac{\partial y}{\partial w_6}=-375.375 w6e=yew6y=375.375
最后,我们就可以运用梯度下降法来更新 w 6 w_6 w6,假设学习率 α = 0.005 \alpha=0.005 α=0.005,则
w 6 ′ = w 6 − α ⋅ ∂ e ∂ w 6 = 3.0 − 0.005 ∗ ( − 375.375 ) = 4.876875 \begin{aligned} w_6'&=w_6-\alpha·\frac{\partial e}{\partial w_6}\\&=3.0-0.005*(-375.375)\\&=4.876875 \end{aligned} w6=w6αw6e=3.00.005(375.375)=4.876875

同理,我们可以按照上面的方法更新 w 5 w_5 w5

上面已经求得 ∂ e ∂ y = − 68.25 \frac{\partial e}{\partial y}=-68.25 ye=68.25.
根据 y = w 5 h 1 + w 6 h 2 y=w_5h_1+w_6h_2 y=w5h1+w6h2
∂ y ∂ w 5 = h 1 = 2.5 \frac{\partial y}{\partial w_5}=h_1=2.5 w5y=h1=2.5
则,
w 5 ′ = w 5 − α ⋅ ∂ e ∂ w 5 = 2.5 − 0.005 ∗ ( − 68.25 ∗ 2.5 ) = 3.353125 \begin{aligned} w_5'&=w_5-\alpha·\frac{\partial e}{\partial w_5}\\&=2.5-0.005*(-68.25*2.5)\\&=3.353125 \end{aligned} w5=w5αw5e=2.50.005(68.252.5)=3.353125

下面,我们来更新 w 1 , w 2 , w 3 , w 4 w_1,w_2,w_3,w_4 w1,w2,w3,w4.

w 1 w_1 w1,根据链式法则:
∂ e ∂ w 1 = ∂ e ∂ y ⋅ ∂ y ∂ h 1 ⋅ ∂ h 1 ∂ w 1 \frac{\partial e}{\partial w_1}=\frac{\partial e}{\partial y}·\frac{\partial y}{\partial h_1}·\frac{\partial h_1}{\partial w_1} w1e=yeh1yw1h1
∂ e ∂ y \frac{\partial e}{\partial y} ye上面已经求过,根据 y = w 5 h 1 + w 6 h 2 y=w_5h_1+w_6h_2 y=w5h1+w6h2可以求得:
∂ y ∂ h 1 = w 5 = 2.5 \frac{\partial y}{\partial h_1}=w_5=2.5 h1y=w5=2.5
根据 h 1 = w 1 x 1 + w 2 x 2 h_1=w_1x_1+w_2x_2 h1=w1x1+w2x2 可以求得:
∂ h 1 ∂ w 1 = x 1 = 1 \frac{\partial h_1}{\partial w_1}=x_1=1 w1h1=x1=1
所以,
∂ e ∂ w 1 = ∂ e ∂ y ⋅ ∂ y ∂ h 1 ⋅ ∂ h 1 ∂ w 1 = ( − 68.25 ) ∗ 2.5 ∗ 1 = − 170.625 \begin{aligned} \frac{\partial e}{\partial w_1}&=\frac{\partial e}{\partial y}·\frac{\partial y}{\partial h_1}·\frac{\partial h_1}{\partial w_1}\\&=(-68.25)*2.5*1\\&=-170.625 \end{aligned} w1e=yeh1yw1h1=(68.25)2.51=170.625
因此,
w 1 ′ = w 1 − α ⋅ ∂ e ∂ w 1 = w 1 − α ⋅ ( y − t ) ⋅ w 5 ⋅ x 1 = 0.5 − 0.005 ∗ ( − 170.625 ) = 1.353125 \begin{aligned} w_1'&=w_1-\alpha·\frac{\partial e}{\partial w_1}\\&=w_1-\alpha·(y-t)·w_5·x_1\\&=0.5-0.005*(-170.625)\\&=1.353125 \end{aligned} w1=w1αw1e=w1α(yt)w5x1=0.50.005(170.625)=1.353125

同理,
w 2 ′ = w 2 − α ⋅ ∂ e ∂ w 2 = w 2 − α ⋅ ( y − t ) ⋅ w 5 ⋅ x 2 = 1 − 0.005 ∗ ( − 68.25 ) ∗ 2.5 ∗ 2 = 2.70625 \begin{aligned} w_2'&=w_2-\alpha·\frac{\partial e}{\partial w_2}\\&=w_2-\alpha·(y-t)·w_5·x_2\\&=1-0.005*(-68.25)*2.5*2\\&=2.70625 \end{aligned} w2=w2αw2e=w2α(yt)w5x2=10.005(68.25)2.52=2.70625
w 3 ′ = w 3 − α ⋅ ∂ e ∂ w 3 = w 3 − α ⋅ ( y − t ) ⋅ w 6 ⋅ x 1 = 1.5 − 0.005 ∗ ( − 68.25 ) ∗ 3 ∗ 1 = 2.52375 \begin{aligned} w_3'&=w_3-\alpha·\frac{\partial e}{\partial w_3}\\&=w_3-\alpha·(y-t)·w_6·x_1\\&=1.5-0.005*(-68.25)*3*1\\&=2.52375 \end{aligned} w3=w3αw3e=w3α(yt)w6x1=1.50.005(68.25)31=2.52375
w 4 ′ = w 4 − α ⋅ ∂ e ∂ w 4 = w 4 − α ⋅ ( y − t ) ⋅ w 6 ⋅ x 2 = 2 − 0.005 ∗ ( − 68.25 ) ∗ 3 ∗ 2 = 4.0475 \begin{aligned} w_4'&=w_4-\alpha·\frac{\partial e}{\partial w_4}\\&=w_4-\alpha·(y-t)·w_6·x_2\\&=2-0.005*(-68.25)*3*2\\&=4.0475 \end{aligned} w4=w4αw4e=w4α(yt)w6x2=20.005(68.25)32=4.0475

现在,我们已经更新了所有的权重值:
w 1 , w 2 , w 3 , w 4 , w 5 , w 6 = 1.353125 , 2.70625 , 2.52375 , 4.0475 , 3.353125 , 4.876875 w_1,w_2,w_3,w_4,w_5,w_6=1.353125,2.70625,2.52375,4.0475,3.353125,4.876875 w1,w2,w3,w4,w5,w6=1.353125,2.70625,2.52375,4.0475,3.353125,4.876875
再次根据前向运算可以求得 y = 74.47 y=74.47 y=74.47 e = 136.58 e=136.58 e=136.58,我们可以发现,新的预测值更接近真实值了,而且,误差变得更小了,就这样,只要重复上面的步骤,我们就可以得到越来越准确的预测模型了!!!!

这就是我对反向传播算法的理解,希望对大家有所帮助。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值