为了节省显存,pytorch在反向传播的过程中只保留了计算图中的叶子结点的梯度值,而未保留中间节点的梯度,但是在给复杂的网络做了一些改变,比如从某个中间变量出发,增加了一种新的loss作为总loss的一部分,需要确认总loss在进行BP时会沿着新增的loss计算的过程这一路径回传到中间变量,因此需要对中间变量的梯度值进行检查。在自动BP的情况下,可以通过register_hook的方式来实现:
yinbianliang = function(zibianliang) #自变量到因变量之间的计算过程,其中zibianliang不是叶子节点
def extract(g):
global features_grad
features_grad = g
zibianliang.register_hook(extract) #设置钩子
yinbianliang.backward() #执行BP,钩子勾到数据,执行extract函数
######################
d_yinbian_div_d_zibian = features_grad
d_yinbian_div_d_zibian 就是获得的因变量对自变量的梯度值