pytorch训练神经网络爆内存解决办法
训练的时候内存一直在增加,最后内存爆满,被迫中断。
后来换了一个电脑发现还是这样,考虑是代码的问题。
检查才发现我的代码两次存了loss,只有一个地方写的是loss.item()。问题就在loss,因为loss是variable类型。
要写成loss_train = loss_train + loss.item(),不能直接写loss_train = loss_train + loss。否则就会发现随着epoch的增加,占的内存也在一点一点增加。
算是一个小坑吧,希望大家还是要仔细。