是否你也是自己写损失函数遇到了问题,推荐阅读Pytorch自动求导的细节
下面是自己的一点笔记,详细内容还是阅读推荐的文章,写的很棒!
在求导过程中,只有设置了可导的叶子节点的导数会保存下来,非叶子结点的导数在用完就释放掉了。
查看一个节点w1是否是叶子节点以及是否需要求导的程序:
print(w1.is_leaf, w1.requires_grad)
如果想在最后可以打印某个非叶子结点(如 l1)的导数值,则在loss.backward()之前加上下面这句
l1.retain_grad()
如果对叶子节点做了in-place操作,那么叶子节点就会变成非叶子结点,从而在求导的过程中不会保存导数值,而使得loss.backward()发生错误。
如果想修改某个叶子节点的值,可通过下面的方式:
a = torch.tensor([10., 5., 2., 3.], requires_grad=True)
#--1---
a.detach().fill_(10.)
#--2---
with torch.no_grad():
a[:] = 10.
loss = (a*a).mean()
loss.backward()