pytorch load state dict_pytorch源码阅读(二)optimizer原理

v2-55ef33e52e09751b25350b6f65edc0da_1440w.jpg?source=172ae18b

pytorch包含多种优化算法用于网络参数的更新,比如常用的SGD、Adam、LBFGS以及RMSProp等。使用中可以发现各种优化算法的使用方式几乎相同,是因为父类optimizer【1】定义了各个子类(即SGD等)的核心行为,下面是optimizer类注释:

class 

其中首句“所有优化器的基类” 表明所有的优化器都必须继承optimizer类,下面来分析optimizer类的的各个实例函数。

1、初始化__init__()

def 

优化器需要保存学习率等参数的值,所以optimizer类需要用实例属性来存储这些参数,也就是__init__()中的self.param_groups,下面的代码通过一个全连接网络来测试优化器的param_groups包含哪些参数:

net 

得到:

[{

其中2x2的矩阵是net的权重矩阵,1x2为偏置矩阵,其余为优化器的其它参数,所以说param_groups保存了优化器的全部数据,这个下面的state_dict()不同。

2、优化器状态state_dict()

def 

查看上一节定义的optimizer的state_dict():

print

可以到优化器的完整参数如下:

[{

3、优化器参数加载load_state_dict()

上一节中的state_dict()负责提取优化器的参数,可以保存到本地用于下次训练恢复使用,对应的必然有load_state_dict()用于优化器参数的加载,其源码如下:

def 

为了测试state_dict()和load_state_dict(),可以首先存储一个学习率为100的优化器的参数到本地:

optimizer_old 

现在这个优化器的参数已经存储到本地,然后将这个优化器参数重新加载给一个新的学习率为0.01优化器:

optimizer_new 

得到new优化器的学习率不是0.01,而是old优化器的学习率100:

[{

4、梯度清空zero_grad()

在网络优化过程中optimizer.zero_grad()函数需要被显式调用,负责清空其关联网络的参数梯度值,其源码如下:

def 

这个遍历过程就是获取optimizer的param_groups属性的字典,之中的["params"],之中的所有参数,通过遍历设定每个参数的梯度值为0。

5、单步更新step()

def 

优化器的step()函数负责更新参数值,但是其具体实现对于不同的优化算法是不同的,所以optimizer类只是定义了这种行为,但是并没有给出具体实现。

6、总结

优化算法部分的代码并不多,但是不同的优化算法涉及的概念较多,看懂各种算法的实现需要很强的数学功底。optimizer类定义了各种优化算法的公共行为与抽象方法,是典型的面向对象的继承思想。

参考:

【1】https://github.com/pytorch/pytorch/blob/master/torch/optim/optimizer.py

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值