训练时
在训练时,检测 “out of memory” 的error并通过torch.cuda.empty_cache()处理
如:
try:
outputs = net(inputs)
except RuntimeError as exception:
if "out of memory" in str(exception):
print('WARNING: out of memory, will pass this')
torch.cuda.empty_cache()
continue
else:
raise exception
测试时
在测试时,避免忘记设置 torch.no_grad()
如:
with torch.no_grad():
inputs = None
outputs = model(inputs)