前言
最近的工作中,用到了Pytorch框架训练医学图像分割模型。精心设计的模型经常会因为显存不足而失败。减小模型训练过程中对显存的占用,可能我们能想到最简单的方法就是减小batchsize,减少卷积核数量,裁剪输入图像的大小等。但是,以上方法可能会影响模型性能。经过多次尝试,总结了几种尽可能不改变模型结构,不影响模型性能且能够节省显存的训练方法。
清理GPU缓存
Pytorch提供了torch.cuda.empty_cache()函数用来清理GPU的缓存数据,在实际的模型训练中,我们可以将此函数与异常处理相结合,一旦发生显存溢出,异常处理机制可以捕获异常信息,清理GPU缓存,以保证模型正常训练。总体实现代码如下:
...
try:
output_data = model(input_data)
loss = calcloss(output_data,input_label)
...
except RuntimeError as exception:
if "out of memory" in str(exception):
print("GPU 显存不足!")
if hasattr(torch.cuda,"empty_cache"):
torch.cuda.empty_cache()
else:
raise exception
...
此方法能够起作用的前提是模型自身占用显存较小,如果模型过大,很容易丢失大量训练数据,影响模型性能。
改进反向传播中activation