RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
报错信息为一个inpalce操作和梯度冲突了
1.什么是inplace操作?
inplace操作即用该变量更新该变量自己的值,形如下面的
x+=1
b=x.exp_()(x.exp_()inplace操作了)
2.错误产生的原因
如下面代码对结果backward()就会报RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation错误
import torch
x = torch.FloatTensor([[1., 2.]])
w1 = torch.FloatTensor([[2.], [1.]])
w2 = torch.FloatTensor([3.])
w1.requires_grad = True
w2.requires_grad = True
d = torch.matmul(x, w1)
f = torch.matmul(d, w2)
d[:] = 1 # 因为这句, 代码报错了 RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
f.backward()
错误产生的原因为:
3.如何定位到该错误发生的行?
(1)如果你代码较短,可以找有以下形式的inplace操作的形式
如 x+=1; inpalce=true,x[1]=x[1]+1等
(2)该错误一般是在loss.backward()的时候报,并不会给出错误发生具体的行,可以见loss.backward()写在with torch.autograd.set_detect_anomaly(True):下面,参考github
import torch
with torch.autograd.set_detect_anomaly(True):
a = torch.rand(1, requires_grad=True)
c = torch.rand(1, requires_grad=True)
b = a ** 2 * c ** 2
b += 1
b *= c + a
d = b.exp_()
d *= 5
b.backward()
这样可以帮你快速定位到inplace操作所在的行
(3)如果你模型比较大,代码比较多,而且代码中的inplace操作是隐式的,不好一眼看错哪行代码inplace操作了,那么将每个模型单独拿出来loss.backward(),出现RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation错误的那个模型的代码就是有问题的。
(4)最终一定能定位到inplace操作的方法
一般来说该错误发生在模型的forward()函数里,那么你可以将forward函数里面的每个变量单独拿出来backward(),
如
def forward(self, feature, sents):
w_enb = self.w_emb(sents)
# w_enb.backward(torch.ones(2, 14, 600).to(device))# 没问题
context = self.q_emb.forward(w_enb)
context.backward(torch.ones(2, 14, 1024).to(device)) # 没问题
# context.backward(torch.ones(2, 14, 1024).to(device))这行没问题
# 说明这行前面的变量和模型都没有inplace操作
# 如果context.backward(torch.ones(2, 14, 1024).to(device))出现
# RuntimeError: one of the variables needed for gradient computation
# has been modified by an inplace operation,那就说明context这个变量就是那个有问题的变量
# 就是那个inplace操作了的变量
4.如何解决这个inpalce操作带来的问题
解决方案1:把所有的inplace=True改成inplace=False
解决方案2:将out+=residual这样所有的+=操作,改成out=out+residual
解决方案3:对于output1[:, 0, :, :, :] = 1 - output1[:, 0, :, :, :]这种操作,正确的写法是
output2 = output1.clone()
output2[:, 0, :, :, :] = 1 - output1[:, 0, :, :, :]
要先创建一个有独立内存空间的中间变量