解决pytorch反向传播过程中出现 RuntimeError: Trying to backward through the graph a second time 问题

最近写代码遇到这个问题,先展示下问题,整个问题的代码放在文章最后。

 比较常见的问题,通常都是第一次迭代没问题,第二次迭代就出现这个错误,以下展示我的部分代码,并且从头到尾分析以下我的解决方案。

可以一边看我的代码一边看问题,这里只是示范代码,所以并没有写的很复杂。

问题的出现在于我在循环里使用了如下的更新z1的公式,

 

 z1是通过  z1 = cal(t0,z0,t1,z1,z2) 得到的,这里z1是已经定义了的变量,但是经过这个公式又把z1更新覆盖了,虽然说这个错误可以在最后loss.backward()函数里加上retain_graph=Ture,即loss.backward(retain_graph=Ture)解决,但是这会增加显存的使用,明显当任务需要较大的计算量时,这是不合理的,所以这种方案并不是最优的。

而我的解决方案是对 z1 = cal(t0,z0,t1,z1,z2) 进行了

评论 10
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值