one of the variables needed for gradient computation has been modified by an inplace operation

本文主要讨论pytorch在训练时出现错误:Runtime Error:one of the variables needed for gradient computation has been modified by an inplace operation 的不同情况。

对需要求导的Tensor使用了 inplace 操作

inplace操作类似于a+=5即使用变量参与了运算,又修改了变量的值。
这类问题在这里讲解的比较清楚。
比如一个张量a用于计算张量b,随后又对a重新赋值,但是使用b求解loss并反向传播,会导致这一问题。

1.5以上版本pytorch训练GAN

使用1.5以上版本的pytorch训练GAN的时候也会出现这样的问题,但这个时候检查代码会发现模型结构可能并没有inplace的操作。这是由于版本更新了多个模型优化器step()方法。
这一问题在这里可找到比较详细的解释。
在1.4及以前的pytorch,我们训练GAN通常遵循以下步骤:

fake = G(input) #generate fake
outFake = D(fake) #try to fool discriminator
outReal = D(real)
lossD = - torch.mean(torch.log(outReal ) + torch.log(1. - outFake))
lossG = torch.mean(torch.log(1. - outFake))

optD.zero_grad()
lossD.backward(retain_graph=True)
optD.step()

optG.zero_grad()
lossG.backward()
optG.step()

也就是先通过生成器G生成假的输出fake,分别输入判别器D得到结果,计算G和D的loss,分别反向传播。由于backward两次,且两个网络是“连通的”(通过fake张量)所以第一次backward时加上retain_graph=True即可。
这个过程在1.4以及之前的版本是没有问题的。

但是换到1.5以上,就会出现本文所描述的错误,通常还会在此之前先报例如“cudnnConvBackward”或其他网络层的反向传播错误,随后traceback到这个错误。

原因在于,D的优化器在第一次lossD反向传播时 inplace的修改了discriminator的参数,而优化G有需要用到D,导致了这一错误。
解决方案是,我们需要人为分开D和G使用的两个loss,例如

fake = G(input) #generate fake
outFake = D(fake) #try to fool discriminator

lossG = torch.mean(torch.log(1. - outFake))
optG.zero_grad()
lossG.backward()
optG.step()

outReal = D(real)
outFake = D(fake.detach())
lossD = - torch.mean(torch.log(outReal ) + torch.log(1. - outFake))
optD.zero_grad()
lossD.backward()
optD.step()

先生成fake,通过D判别并计算G的loss,反向传播修改G的参数。
接下来再次将fake和real传入D中,此时fake加上detach,将已经修改过的G分离开,并计算D的loss从而优化D。

这一问题是1.5以后才出现的,主要是由于GAN的两部分网络通过生成的张量连接了起来,两个网络分别优化反向传播时,有一个会先被inplace的修改掉,从而产生错误。而我们知道,GAN本身两个网络是要分别修改交替优化的。所以在适当的位置detach分离可以让两个网络在一次迭代中正常的优化。

上述的解答是官方在对应issue中提到的解决方案,个人测试后发现确实可以避免这样的错误。

  • 9
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值