pytorch基础知识整理(五) 优化器

深度学习网络必须通过优化器进行训练。在pytorch中相关代码位于torch.optim模块中。

1, 常规用法

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
for data,target in train_loader:
	...
	optimizer.zero_grad()
	output = model(data)
	loss = criterion(output, target)
	loss.backward()
	optimizer.step()  

2, optimizer的方法和属性

  • optimizer.zero_grad()的作用是清空模型中所有参数的梯度。

注:pytorch默认是进行梯度自动累加的,所以要使用optimizer.zero_grad()对梯度进行清零,如果遇到多个loss相加时,或者用多次循环再更新梯度法放大batch_size时,就不用optimizer.zero_grad()了。

  • optimizer.step()的作用是根据梯度更新参数,所以放在loss.backward()计算梯度之后。
  • state_dict()获得optimizer的状态字典,里面存放有param_groups和state。可通过torch.save()来保存到硬盘。
  • load_state_dict()用来加载保存的状态字典,以继续训练。
  • param_groups是一个list列表,存放各参数分组param_group。各参数分组param_group中包括学习率、momentum、weight_decay、dampening、nesterov以及各参数张量。
  • 以及add_param_group、defaults等

3,对不同的参数分组设置

对一个网络的不同层设置不同的学习率

optimizer = torch.optim.SGD([
        {'params': other_params}, 
        {'params': first_params, 'lr': 0.01*args.learning_rate},
        {'params': second_params, 'weight_decay': args.weight_decay}],
        lr=args.learning_rate,
        momentum=args.momentum,
)

对多个网络联合训练时,用一个optimizer控制所有网络:

optimizer = torch.optim.AdamW([{'params': Hnet.parameters(), 'lr': lr},
                               {'params': Pnet.parameters(), 'lr': lr},
                               {'params': Dnet.parameters(), 'lr': lr},])

4,各种优化器

torch.optim.
Adam(), SGD(), Adadelta(), Adagrad(), LBFGS(), AdamW(), SparseAdam(), Adamax(), ASGD(), RMSprop(), Rprop()
还有一些非pytorch自带的优化器,详见torch_optimizer(通过pip install torch-optimizer安装)。

5,lr_scheduler

根据epoch数或其他条件实现动态学习率,模板如下:

optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
scheduler = ReduceLROnPlateau(optimizer, 'min')
for epoch in range(10):
     train(...)
     val_loss = validate(...)
     # Note that step should be called after validate()
     scheduler.step(val_loss)

学习率调整方法有:
ReduceLROnPlateau(),
LambdaLR(),
StepLR(),
MultiStepLR(),
CosineAnnealingLR(),
CosineAnnealingWarmRestarts(),
CyclicLR(),
OneCycleLR()

6,手动设计lr_scheduler

不使用pytorch自带的lr_scheduler,手动设计也不难:

for epoch in range(N):   
	if epoch > 10:
	    for param_group in optimizer.param_groups:
	        param_group['lr'] = lr/10
	train(epoch)
  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值