深度学习测试的时候爆显存
我在训练的时候好好的,测试的时候输入变大一些,但是batchsize只有1,竟然爆显存了。
原因:在测试的时候,没有加上with torch.no_grad()
因为当模型在测试数据时,每次运行测试的代码,依旧会计算梯度得到新的特征图,所以显存占用逐步增多。实际上在测试的过程中,只需要网络计算输出结果,不需要网络计算梯度
除此之外,为了进一步释放测试或者训练过程中的显存,可以在代码中加入torch.cuda.empty_cache()
来释放掉显存中的中间变量,具体原理看Pytorch训练模型时如何释放GPU显存 torch.cuda.empty_cache()内存释放以及cuda的显存机制探索
例:
with torch.no_grad():
output = net(input)
torch.cuda.empty_cache()
参考