backward的 retain_variables=True与retain_graph=True 区别及应用

首先 :retain_variables=True 与 retain_graph=True   没有任何区别 

        只是 retain_variables 会在pytorch 新版本中被取消掉,将使用 retain_graph 。

        所以 在使用 retain_variables=True 时报错:

        TypeError: backward() got an unexpected keyword argument 'retain_variables'

 

应用:

两者都是backward()的参数   默认 retain_graph=False,

也就是反向传播之后这个计算图的内存会被释放,这样就没办法进行第二次反向传播了,

所以我们需要设置为 retain_graph=True,

loss_1.backward(retain_graph=True) 

 

retain_graph:

每次 backward() 时,默认会把整个计算图free掉。一般情况下是每次迭代,只需一次 forward() 和一次 backward() ,前向运算forward() 和反向传播backward()是成对存在的,一般一次backward()也是够用的。但是不排除,由于自定义loss等的复杂性,需要一次forward(),多个不同loss的backward()来累积同一个网络的grad,来更新参数。于是,若在当前backward()后,不执行forward() 而可以执行另一个backward(),需要在当前backward()时,指定保留计算图,即backward(retain_graph)。

 

 

参考链接:https://www.pytorchtutorial.com/pytorch-backward/     不过这篇写的是旧版本的了!

已标记关键词 清除标记
相关推荐
©️2020 CSDN 皮肤主题: 技术黑板 设计师:CSDN官方博客 返回首页