在训练时,明明前几个epoch都能正常跑,但是到了某一个epoch突然给我报显存不够了。我寻思你跑完epoch难道不释放显存的吗,为啥epoch和epoch之间的差距还这么大?
经过多方查询,多种方法的尝试,最终定位到loss的计算上。
原来我在反向传播完后会累加loss,以计算平均损失打印出来并写入tensorboard,而我在累加loss时用的是loss_epoch += loss ,此时loss会被放入计算图中一起保存,实际上除了反向传播以外,我的loss只是要那个数值而已。
所以解决办法就是在所有只需要数值的loss变量后面加上.item()
例如:loss_epoch += loss.item()
至此,问题解决。