使用pytorch0.4框架,在train的时候不会报错,在validation时候报错out of memory错误。
解决方法:
在train过程前需要设置如下:
torch.set_grad_enabled(True)
# switch to train mode
model.train()
在validation过程前需要设置如下:
# In PyTorch 0.4, "volatile=True" is deprecated.
torch.set_grad_enabled(False)
# switch to evaluate mode
model.eval()