0. 问题描述
register_hook
用于给某个tensor注册hooks,- 这个函数需要传入一个钩子函数,而且钩子函数的input是
loss.backward()
执行后的grad
(不能获取weight值)
- 这个函数需要传入一个钩子函数,而且钩子函数的input是
- 笔者这个时候loss不收敛,debug发现梯度为0,因此通过加钩子,试图发现在传播时哪里出了问题。因此发现了
register_hook
并不是100%能work,而且不是100%的可以打印出grad
1. 探究过程
- 钩子函数:
def save_grad(name):
print("****") # 这行可以说明有没有执行钩子函数
def hook(grad):
print(f"name={name}, grad={grad}")
return hook
- 注册过程
# U_head是loss function中的一个中间tensor,需要计算梯度
U_head.register_hook(save_grad("U_head"))
2. 不work的原因
-
来自Stack Overflow的建议:
- ‘register_hook’ won’t only in two cases:
- It was registered on a Tensor for which the gradients was never computed./ 梯度没计算
- The register_hook function is some part of your code that did not run during the forward. / 正向传播没执行这条语句
- ‘register_hook’ won’t only in two cases:
-
因此结合了自己的代码,修正了如下几个bug,然后就work了
- 最关键的修改:
for t in idx: # work了 U_tail= torch.cat([U_head, f_Equi_ts(t, T,v, d, b, alpha, labda,device)]) # 这样写不work,莫非属于原地修改? # U_tail = torch.tensor([f_Equi_ts(t, T,v, d, b, alpha, labda,device) for t in idx],requires_grad=True)
- 对于 需要loss function的代码重写,尽可能简单,易读性高,先不追求效率