这两天在看CS231n的课程,不得不感慨我国目前的教育制度确实很难培养出真正高端的人才,别人的课程有那么多的精心准备的课件、笔记和供学生练习的任务和代码,让你从原理到工程的实现全过一遍,反观我所在的所谓的双一流大学的课程。。呵呵哒。。。不说闲话了,这里讲下反向传播中的链式法则的实现细节。
我们以这个公式为例,实现反向传播的链式法则,在神经网络中的反向传播的方式是一样的,只是函数比这个复杂些。
首先要把原函数拆解成与x,y有关的各部分的乘积与和,注意这里一定要拆成和和积的形式,不然的话求解参数更新值dx,dy的时候规则就会不统一。
拆解如下,sigma(x)就是1/(1+e^(-x))这个函数
x = 3 # example values
y = -4
# forward pass
sigy = 1.0 / (1 + math.exp(-y)) # sigmoid in numerator #(1)
num = x + sigy # numerator #(2)
sigx = 1.0 / (1 + math.exp(-x)) # sigmoid in denominator #(3)
xpy = x + y #(4)
xpysqr = xpy**2 #(5)
den = sigx + xpysqr # denominator #(6)
invden = 1.0 / den #(7)
f = num * invden # done! #(8)
然后我们对拆解的部分根据链式法则倒着推回去
#对 f = num * invden 求偏导
dnum = invden # gradient on numerator #(8)
dinvden = num #(8)
#这里分成对num和对invden各自求偏导两部分
#对 invden = 1.0 / den 求偏导
dden = (-1.0 / (den**2)) * dinvden #(7)
# 对den = sigx + xpysqr求偏导
dsigx = (1) * dden #(6)
dxpysqr = (1) * dden #(6)
# 对 xpysqr = xpy**2求偏导
dxpy = (2 * xpy) * dxpysqr #(5)
# 对xpy = x + y求偏导
dx = (1) * dxpy #(4)
dy = (1) * dxpy #(4)
# 对sigx = 1.0 / (1 + math.exp(-x))求偏导
dx += ((1 - sigx) * sigx) * dsigx # 注意这里是+= #(3)
# 对 num = x + sigy 求偏导
dx += (1) * dnum #(2)
dsigy = (1) * dnum #(2)
# 对sigy = 1.0 / (1 + math.exp(-y))求偏导
dy += ((1 - sigy) * sigy) * dsigy #(1)
# 完成
这里注意两点:
1:不要忘记乘上原函数的值,这是链式法则的基本规则
2:注意多元函数求偏导,按照这种拆解的方法进行时,要将各个部分对变量更新的贡献加起来,这就是为什么我之前说要按乘积与和的形式进行拆解的原因,否则如果拆解的时候出现差的形式((a)+(-b)变成了(a)-(b)),那么在计算参数值更新值的时候就会出现有的地方+=有的地方是-=的情况,很容易出错。