torch.optim 之如何使用优化器optimizer

torch.optim模块包含了不同的优化器,支持大多数常用的优化算法,接口很通用。torch.optim创建一个优化器实体,保存当前model的状态,并且通过计算的梯度更新参数。创建时需要传给torch.optim一个包含model参数的迭代器,然后给该优化器指定learning rate、weight decaly等参数。(需要注意的:如果使用GPU,optimizers的创建需要model.cuda()之后)

1、SGD优化器函数原型

这里params可以是字典类型的,或是模型参数迭代器

函数原型
def __init__(self, params, lr=required, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False)

2、param_groups的结构及作用

这里为什么要插入param_groups的结构及作用呢,是因为不同的初始化方式,param_groups的长度久不同。optimizer.param_groups是一个list结构,list中的元素是字典,字典的key是模块参数params,模块学习率lr,以及dampening、weight_decay、nesterov 6个元素,因此param_groups的结构是:

[
  {'params'     :  ,
  'lr'          :  ,
  'momentum'    :  , 
  'dampening'   :  , 
  'weight_decay':  , 
  'nesterov'    :  
  },
  
  {……}{……}]

param_groups里保存的优化器在不同模块上的参数,帮助你为不同的子网络设定不同的学习率,finetune时常使用该策略。

3、SGD优化器使用例子

#方法一、传递模型参数迭代器 
#lr=0.01, momentum=0.9 是默认参数                        
torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
len(optim.param_groups) = 1
[
  {'params'     :  ,
  'lr'          :  0.01,
  'momentum'    :  0.9, 
  'dampening'   :  0, 
  'weight_decay':  0.0001, 
  'nesterov'    :  False
  }
]



#方法二、传递一个字典
#lr=1e-2, momentum=0.9 是默认参数
# model.classifier模块的学习率指定学习率,'lr': 1e-3
# model.base模块的学习率是默认学习率,即lr=1e-2
optim.SGD([{'params': model.base.parameters()},
      {'params': model.classifier.parameters(), 'lr': 1e-3, 
      "momentum" :0.9, "weight_decay" :1e-4}], 
      lr=1e-2, momentum=0.9)

len(optim.param_groups) = 2

[
  {'params'     :  ,
  'lr'          :  1e-05,
  'momentum'    :  0, 
  'dampening'   :  0, 
  'weight_decay':  0, 
  'nesterov'    :  False
  },
  
  {'params'     :  ,
  'lr'          :  0.01,
  'momentum'    :  0.9, 
  'dampening'   :  0, 
  'weight_decay':  0.0001, 
  'nesterov'    False:  
  },
  
]
  • 1
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 6
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值