Pytorch训练过程中显存不断增加原因之一
在使用pytorch利用测试集进行网络预测时,给网络输入数据,默认会构建计算图,构建计算图是为了方便后续的反向传播进行梯度计算,如果只是为了利用网络进行预测,则不需要构建完整的计算图。构建完整计算图会增加计算和累积内存消耗,导致所占GPU显存越来越大。
解决方案
在测试代码处于如下命令下:
with torch.no_grad():
例如:
with torch.no_grad():
prediction = net(images)
loss = loss_func(prediction , label) / batch_size