训练一个模型,显存一直在不断增加....肯定是什么变量一直被保存在计算图里面。
经过排查发现,我从模型里返回了一个特征(因为这个特征后续要参与计算)。
但是feature没有参与loss计算,导致feature没有被反向传播到,一直保留在计算图中。
参考:
如何解决pytorch程序运行时内存消耗一直增加的问题? - 浮生号的回答 - 知乎 https://www.zhihu.com/question/276797963/answer/2355051638
解决方法:
在将feature传出模型前,先detach()