PyTorch里eval和no_grad的关系

首先这两者有着本质上区别。

model.eval()是用来告知model内的各个layer采取eval模式工作。这个操作主要是应对诸如dropoutbatchnorm这些在训练模式下需要采取不同操作的特殊layer。训练和测试的时候都可以开启。
torch.no_grad()则是告知自动求导引擎不要进行求导操作。这个操作的意义在于加速计算、节约内存。但是由于没有gradient,也就没有办法进行backward。所以只能在测试的时候开启。

所以在evaluate的时候,需要同时使用两者。

model = ...
dataset = ...
loss_fun = ...

# training
lr=0.001
model.train()
for x,y in dataset:
	model.zero_grad()
	p = model(x)
	l = loss_fun(p, y)
	l.backward()
	for p in model.parameters():
		p.data -= lr*p.grad
	
# evaluating
sum_loss = 0.0
model.eval()
with torch.no_grad():
	for x,y in dataset:
		p = model(x)
		l = loss_fun(p, y)
		sum_loss += l
print('total loss:', sum_loss)

另外no_grad还可以作为函数是修饰符来用,从而简化代码。

def train(model, dataset, loss_fun, lr=0.001):
	model.train()
	for x,y in dataset:
		model.zero_grad()
		p = model(x)
		l = loss_fun(p, y)
		l.backward()
		for p in model.parameters():
			p.data -= lr*p.grad
	
@torch.no_grad()
def test(model, dataset, loss_fun):
	sum_loss = 0.0
	model.eval()
	for x,y in dataset:
		p = model(x)
		l = loss_fun(p, y)
		sum_loss += l
	return sum_loss

# main block:
model = ...
dataset = ...
loss_fun = ...

# training
train()
# test
sum_loss = test()
print('total loss:', sum_loss)

参考:
https://pytorch.org/docs/stable/generated/torch.no_grad.html
https://discuss.pytorch.org/t/model-eval-vs-with-torch-no-grad/19615

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值