torch.optim

本文详细介绍了PyTorch库中的四个常用优化器:SGD、Adagrad、RMSprop和Adam,涵盖了它们的基本用法、参数解读及各自特点。通过实例演示了如何在模型训练中应用这些优化器,并讨论了Adagrad的相对较少使用。
摘要由CSDN通过智能技术生成

optimizer优化器一节中讲解了四个优化器,分别是GD, SGD, SGDM, Adagrad, RMSProp, Adam,在PyTorchtorch.optim中包含了后五个,这里讲解这五个优化器的PyTorch使用方法

一、torch.optim.SGD

torch.optim.SGD包含了SGD以及SGDM

torch.optim.SGD(params, lr=<required parameter>, momentum=0, dampening=0, weight_decay=0, nesterov=False)

  • params:待优化的参数,也就是model中的parameteres
  • lr:学习率,需要给定
  • momentum:可选项,不选择为SGD,选择为SGDM
  • dampening:动量抑制因子,与SGD无直接关系
  • weight_decay:参数L2惩罚,与SGD无直接关系
  • nesterov:是否使用nesterov动量,默认为False

首先我们看一下PyTorch官方示例然后讲解各个参数的含义

optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
optimizer.zero_grad()
loss_fn(model(input), target).backward()
optimizer.step()

第一步输入构建SGD优化器,第二步将上一步中优化器的梯度清0(这一步是必须的,因为在每一次循环时梯度会进行累加,也就是说如果不清0的话这次所求的梯度是上次与这次的累加),第三步进行反向传播,第四步根据梯度来更新参数

其中params=model.parameters()表示优化模型的参数,lr=0.1表示学习率为0.1,momentum=0.9表示采用SGDM优化算法,还记得之前我们讲过的SGDM先求出移动步长,在进行更新,其实还有另一种更新,这两种更新思路一样,但是就是学习率的位置不同,如下图所示。这种更新就是设置nesterov=True,教程中也给出了两种更新的区别,不过一般来说我们都采取之前讲过的更新方法,也就是下图的第一种,注意其中的 μ \mu μ就是我们的momentum因子

在这里插入图片描述

还有两个参数,一个是 weight_decay,另一个是dampening,第一个是给权重加入L2惩罚,之后来介绍这一点。第二个是动量抑制因子,这个好像没怎么用到,就不作介绍。

二、torch.optim.Adagrad

torch.optim.Adagrad(params, lr=0.01, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10)

  • params:待优化的参数,也就是model中的parameteres
  • lr:学习率,默认为0.01
  • lr_decay:可选项,默认为0,与Adagrad无直接关系
  • weight_decay:参数L2惩罚,与Adagrad无直接关系
  • eps:提高数值稳定性,避免分母为0

这里就不详细叙述了,因为感觉用这个优化的很少,而且我将SGD换成Adagrad之后感觉效果还不好(可能是学习率直接降的非常非常低),而且使用之后感觉lr_decay参数没什么用

三、torch.optim.RMSprop

torch.optim.RMSprop(params, lr=0.01, alpha=0.99, eps=1e-08, weight_decay=0, momentum=0, centered=False)

  • params:待优化的参数,也就是model中的parameteres
  • lr:学习率,默认为0.01
  • alpha:平滑常数(衰减速率),默认为0.99
  • eps:提高数值稳定性,避免分母为1e-8
  • weight_decay:参数L2惩罚,与RMSprop无直接关系
  • momentum:0,与RMSprop无直接关系
  • centered:False,与RMSprop无直接关系

四、torch.optim.Adam

torch.optim.Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)

  • params:待优化的参数,也就是model中的parameteres
  • lr:学习率,默认为0.001
  • betas:beta的值
  • eps:提高数值稳定性,避免分母为1e-8
  • weight_decay:参数L2惩罚,与Adam无直接关系
  • amsgrad:是否使用Adam的变体AMSGrad,False
  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值