最近需要学习pytorch,对backward的底层逻辑十分感兴趣,把个人的学习笔记记录在此。
官方文档描述
在pytorch官网中,backward的简介描述如下:
Computes the sum of gradients of given tensors with respect to graph leaves.
也就是用图来计算给定张量的梯度和。
具体描述如下:
The graph is differentiated using the chain rule. If any of tensors are non-scalar (i.e. their data has more than one element) and require gradient, then the Jacobian-vector product would be computed, in this case the function additionally requires specifying grad_tensors. It should be a sequence of matching length, that contains the “vector” in the Jacobian-vector product, usually the gradient of the differentiated function w.r.t. corresponding tensors (None is an acceptable value for all tensors that don’t need gradient tensors).
翻译过来就是,反向传播图是用链式法则求微分的。对于非标量的且需要求梯度的tensor,会计算其Jacobian-vector。同时,如果grad_tensors参数为True,则应当是一个包含Jacobian向量积中的vector的匹配序列,即是一个对于此张量的微分函数的梯度。
qaq好像看不太懂,不过没关系,可以细讲。
反向传播
首先,什么是反向传播?
我们知道,函数求导是存在链式法则的。也就是说
d
f
d
x
=
d
f
d
y
d
y
d
x
=
d
f
d
z
d
z
d
y
d
y
d
x
=
.
.
.
\frac{df}{dx}=\frac{df}{dy}\frac{dy}{dx}=\frac{df}{dz}\frac{dz}{dy}\frac{dy}{dx}=...
dxdf=dydfdxdy=dzdfdydzdxdy=...可以不断一直接下去。
同样地,函数求偏导也有类似的法则。详情可以看cs231n的tutorial。
简而言之,通过反向求导,可以把下一级函数相对于上一级函数的梯度求出,并不断传递下去。
# set some inputs
x = -2; y = 5; z = -4
# perform the forward pass
q = x + y # q becomes 3
f = q * z # f becomes -12
# perform the backward pass (backpropagation) in reverse order:
# first backprop through f = q * z
dfdz = q # df/dz = q, so gradient on z becomes 3
dfdq = z # df/dq = z, so gradient on q becomes -4
dqdx = 1.0
dqdy = 1.0
# now backprop through q = x + y
dfdx = dfdq * dqdx # The multiplication here is the chain rule!
dfdy = dfdq * dqdy
这段函数,分析而言,dfdz=q=x+y=3,dfdq=z=-4
则dfdx=-4,dfdy=-4,这样就求出了f相对于x,y,q,z的梯度,这就是一个反向传播,求导的过程。
神经网络的优化方向是找到损失最小的点,对于损失函数的优化而言,就是依据梯度下降法,不断向着梯度减小的方向前行,而反向传播可以通过层层传播,来求出各层的梯度,有效提高了运算速度。具体愿意可以看这篇blog。
构建反向传播图
在pytorch中,主要利用反向传播图来实现backward()。
上图是从blog中引用而来的pytorch反向传播示意图,pytorch的variable是一个存放会变化值的地理位置,里面的值会不停变化,像装糖果(糖果就是数据,即tensor)的盒子,糖果的数量不断变化。pytorch都是由tensor计算的,而tensor里面的参数是variable形式。
那么pytorch究竟是如何构建反向传播图的呢?
此博客以及此博客中对于代码的解释比较复杂但非常清楚。