动态计算图是在程序前向传播的过程中构建起来的,主要是用来进行反向传播。相比搭建网络结构时关注每一层的计算方式,计算图主要视角是数据节点(Tensor)。
在计算图构建和反向传播过程中存在一些令人混淆的概念,例如is_leaf、requires_grad、detach()、zero_grad()、 retain_grad()、torch.nograd()。从计算图反向传播的角度去理解这些概念,一切就变的清晰了。
动态图中的反向传播
图1 动态计算图
上图是计算图的示意图:X1和X2是两组输入数据Tensor,P1和P2是网络的权重Tensor,Y和Z是计算的中间结果Tensor,Fn是得到中间结果的计算操作。
-
训练网络的最终目的是更新P1和P2的值,因此需要计算loss关于P1和P2的梯度,为了得到关于P1和P2的梯度,需要依次计算loss关于中间结果Y和Z的梯度。
-
权重Tensor的梯度计算不依赖输入数据且输入数据X不需要更新数值,所以X1和X2不需要计算梯度。
因此默认情况下用户创建的输入数据requirs_grad=False, 网络的权重参数requirs_grad=True
1. 叶子节点
-
由用户的上帝之手直接创造出的Tensor为叶子节点, 这些节点没有记录grad_fn参数(例如输入数据网络权重)。
-
由需要梯度计算的叶子节点通过运算衍生出来的Tensor为非叶子节点,这些节点有grad_fn参数。
叶子节点是处在计算图外围或游离于计算图之外的节点,它们是反向传播的末端。
例如下面代码中 X和XX,P1,P2(图2中绿色点)都为叶子节点,Y和loss(图中白色点)都为非叶子节点。
p1=torch.tensor([1.0,2.0,3.0],requires_grad=True)
p2=torch.tensor([1.1,2.2,3.3],requires_grad=True)
x=torch.tensor([4.0,5.0,6.0],requires_grad=False)
xx=x**2
y1=p+xx
y1=y1.detach()
y2=torch.sigmoid(y1) y3=y2+p2 y4=torch.sigmoid(y3) loss=y4.mean()
图2 叶子节点示意图
2. requires_grad
这个属性说明当前的Tensor需要计算来自loss的梯度,由于叶子节点P1和P2需要计算梯度更新参数,所有由他们衍生出的到loss的通路(中间结果)都需要计算梯度(requires_grad==True)。
例如下图中的橙色节点。
图3 需要计算梯度节点的示意图
pytorch中无法将非叶子节点requires_grad设置为False,毕竟将上图中的某个Y终止梯度计算显得很奇怪也不优雅。 若想停止对P1计算梯度,可以直接将P1的requires_grad设置为False,这样后续的Y1,Y2的requires_grad同样也为False。
3. detach()
若此时P1和P2想用两个loss更新怎么办,例如生成对抗网络,此时可以将Y2进行detach(),如下图:
图4 detach示意图
detach()返回的是一个副本,这个副本成为一个叶子节点跟原来的计算图断开了联系,并可以构造出一个新的计算图。这样通过两个loss可以分别计算P1,P2的梯度互不干扰。
4. zero_grad()
若我们给网络喂两次数据,如下图5所示,计算图又多了一条分支。
图5 两次前向传播示意图
- 若此时我们想用两次的loss的梯度更新网络,则对loss1和loss2分别进行反向传播,P1和P2会对两次的梯度进行累加。相当于增加了batch的大小(loss1+loss2),或者实现孪生网络结构的更新。
loss1.back_ward()
loss2.back_ward()
optimizer.step()
#等价于
(loss1+loss2).back_ward()
-
若我们想用两次loss分别对权重参数进行更新,则需要在每次权重参数更新后,将网络的权重参数梯度置零;否则会累加到下一次loss的梯度中,影响权重更新。
loss1.back_ward() //计算loss1的梯度 optimizer.step() //根据loss1梯度更新权重
optimizer.zero_grad() //将网络的权重参数置零 loss2.back_ward() //计算loss2的梯度
optimizer.step() //根据loss2梯度更新权重
5. retain_grad()
计算梯度的最终目的是给优化器去更新权重参数,即只有P1和P2的梯度信息是用到的,其他非叶子节点的梯度信息只起到向后传播的作用,用完就扔没必要保留。
图6 计算梯度的节点
为了在反向传播结束后依旧能保留该节点的梯度信息,需要调用retain_grad()来阻止释放。
6. torch.nograd()
一个上下文管理器,在上下文计算过程中产生的Tensor都不会进行自动求导,相当于从计算图中剥离。用于在评估阶段减少反向传播相关的额外计算和内存消耗。
将所有网络权重参数requires_grad设置False也能达到相同的效果,但是torch.nograd()更简洁优雅。
6. Optimizer
优化器根据权重的参数的梯度更新权重,在构造优化器时需要将需要将用其进行更新的权重参数参数注册一下。优化器就可以对这些权重参数进行梯度清零和权重更新操作。
以上是在使用pytorch中的一些感悟和总结,如有纰漏还望指正。