第三节 backpropagation
L ( θ ) = ∑ n = 1 N C n ( θ ) L(\bm{\theta}) = \sum ^N _{n=1} C^n(\bm{\theta}) L(θ)=∑n=1NCn(θ)
C n ( θ ) C^n(\bm{\theta}) Cn(θ)是某一种loss值。上式中的上标n表示第n组数据,如 C n ( θ ) C^n(\bm{\theta}) Cn(θ)表示:
第 n 个 数 据 x 输 入 的 结 果 y 与 y ^ 之 间 的 距 离 第n个数据\bm x输入的结果y与\hat y之间的距离 第n个数据x输入的结果y与y^之间的距离。
上式若对w求偏导则为:
∂ L ( θ ) ∂ w = ∑ n = 1 N ∂ C n ( θ ) ∂ w \frac {\partial L(\bm{\theta})}{\partial w} = \sum ^N _{n=1} \frac {\partial C^n(\bm{\theta})}{\partial w} ∂w∂L(θ)=∑n=1N∂w∂Cn(θ)
此时只需考虑每个 C n ( θ ) C^n(\bm{\theta}) Cn(θ)的偏导即可。
求 ∂ C n ( θ ) ∂ w \frac {\partial C^n(\bm{\theta})}{\partial w} ∂w∂Cn(θ)可以转为求 ∂ C n ( θ ) ∂ z d z d w \frac {\partial C^n(\bm{\theta})}{\partial z} \frac {dz}{dw} ∂z∂Cn(θ)dwdz.其中 d z d w \frac {dz}{dw} dwdz即为前向传播。
其中的z是进入激活函数之前的,将本阶段之前的输入与权重相乘之后与偏移量相加得到的结果。并且z还需经过一系列复杂处理才能得到C。
Backpropagation-Forward Pass
为了计算 d z d w \frac {dz}{dw} dwdz,将z展开可以知道(假设只有两个特征,更多特征可在其后累加对应特征乘权重):
z = w 1 x 1 + w 2 x 2 + b z = w_1x_1 + w_2x_2+b z=w1x1+w2x2+b
则 ∂ z d w 1 = x 1 \frac {\partial z}{dw_1} = x_1 dw1∂z=x1, ∂ z d w 2 = x 2 \frac {\partial z}{dw_2} = x_2 dw2∂z=x2
可知对w求偏导的结果就是所偏导权重对应的输入。即w对应的input是什么,对应的偏导结果就是什么。
Backpropagation-Backward Pass
再用chain rule对z到c的过程拆解:
假设所用的激活函数为sigmoid function:
a = σ ( z ) a=\sigma(z) a=σ(z)
这个a就是下一个某个将要与其对应权重 w ′ w' w′相乘组合成 z ′ z' z′、 w ′ ′ w'' w′′相乘组合成 z ′ ′ z'' z′′…的输入x。
所以可以写成
∂ C ∂ z = ∂ C ∂ a ∂ a ∂ z \frac {\partial C}{\partial z} = \frac {\partial C}{\partial a} \frac {\partial a}{\partial z} ∂z∂C=∂a∂C∂z∂a
其中 ∂ a ∂ z \frac {\partial a}{\partial z} ∂z∂a就是对sigmoid函数求偏导。
而 ∂ C ∂ a = ∂ C ∂ z ′ ∂ z ′ ∂ a + ∂ C ∂ z ′ ′ ∂ z ′ ′ ∂ a \frac {\partial C}{\partial a} = \frac {\partial C}{\partial z'}\frac {\partial z'}{\partial a} + \frac {\partial C}{\partial z''}\frac {\partial z''}{\partial a} ∂a∂C=∂z′∂C∂a∂z′+∂z′′∂C∂a∂z′′ (chain rule)
如此可得到:
∂ C ∂ z = σ ′ ( z ) ∗ w 3 ∂ C ∂ z ′ + w 4 ∂ C ∂ z ′ ′ \frac {\partial C}{\partial z} = \sigma'(z) * w_3\frac {\partial C}{\partial z'} + w_4\frac {\partial C}{\partial z''} ∂z∂C=σ′(z)∗w3∂z′∂C+w4∂z′′∂C
其中 w 3 w 4 w_3 w_4 w3w4是由上面的Backward pass思路计算出来。