虽然使用net.eval()调到了验证阶段,但是还是要使用
with torch.no_grad():
outputs = Net_(inputs)
来取消验证阶段的loss。
个人猜测出现此原因是由于梯度在验证阶段不回传,造成梯度的累计。
我做实验发现,在验证阶段,第一个batch不会报错,第二个batch就报错。
虽然使用net.eval()调到了验证阶段,但是还是要使用
with torch.no_grad():
outputs = Net_(inputs)
来取消验证阶段的loss。
个人猜测出现此原因是由于梯度在验证阶段不回传,造成梯度的累计。
我做实验发现,在验证阶段,第一个batch不会报错,第二个batch就报错。