这种情况需要检查一下代码有没有除了loss.backward()
之外的对loss进行过操作的地方。
一般有些地方会对loss进行叠加计算如:loss+=loss[i]
,这个写法是错误的,也是导致显存不断增加最终爆炸的原因。因为输出的loss的数据类型是Variable,PyTorch的动态图机制就是通过Variable来构建图。主要是使用Variable计算的时候,会记录下新产生的Variable的运算符号,在反向传播求导的时候进行使用。如果这里直接将loss加起来,系统会认为这里也是计算图的一部分。
所以计算的时候应该用```loss+=loss[i].item(),只取loss的数值进行计算。
记pytorch的大坑之训练的显存不断攀升
最新推荐文章于 2024-05-03 16:00:32 发布