转自简书:解决 pytorch 在训练时由于设置了验证集导致 out of memory (同样可用于测试时减少显存占用)
问题描述
在跑pytorch的时候,在训练阶段监控到显存占用2.7G左右,但到了验证阶段发现占用了3.65G左右,对于我4G显存的显卡来说很容易爆掉。
解决方法
假设一开始训练和验证阶段如下:
# 训练
for i, (train_data, train_label) in enumerate(train_loader, 1):
training
# 验证
for i, (vali_data, vali_label) in enumerate(vali_loader, 1):
validating
改成
# 训练
for i, (train_data, train_label) in enumerate(train_loader, 1):
training
# 验证
with torch.no_grad():
for i, (vali_data, vali_label) in enumerate(vali_loader, 1):
validating
之后训练和验证阶段显存占用就几乎没变了
后来查了下资料:
with torch.zero_grad():
主要用于停止 autograd 模块的工作,以起到加速和节省显存的作用,具体行为就是停止 gradient 计算,从而节省了 GPU 算力和显存。