@[原创]关于解决pytorch训练神经网络时显存一直增长的问题
问题描述
在训练自定义loss和自定义网络结构的一个模型的时候,发现模型和数据都比较简单的情况下,在训练过程中居然把24g的显卡拉爆了显存。
然后使用nvidia-smi -l观察显存变化,发现是有规律的显存一直增加,直到OOM。
问题解决思路
在这个过程中尝试询问了chatgpt,但是发现它提供的解决方案,诸如torch.cuda.memory_cached()/del data 等命令放在每次循环后面并不能解决问题。
所以后面尝试在谷歌进行搜索,找到了下面这篇的知乎的博客:
链接: link.
这篇文章的四种方法其实都没有解决我的问题,但是它的第一种情况给了我一点启发,此外chatgpt在最开始提到的原因也是关键,促成了后面问题的解决。
在直接尝试各种解决手段无果后,我决定自行去查看代码段和显存的使用情况,主要是使用下面这个命令:
print("Memory Allocated:", torch.cuda