【Pytorch】optimizer原理

optimizer原理

【参考笔记】
【源码链接
举个栗子,定义一个全连接网络:

import torch
import torch.nn as nn
import torch.optim as optim

net = nn.Linear(2, 2)
# 权重矩阵初始化为1
nn.init.constant_(net.weight, val=100)
nn.init.constant_(net.bias, val=20)
optimizer = optim.SGD(net.parameters(), lr=0.01)

1. 测试optimizer有哪些属性

print(optimizer.__dict__)

得到:

{'defaults': {'lr': 0.01, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False}, 'state': defaultdict(<class 'dict'>, {}), 'param_groups': [{'params': [Parameter containing:
tensor([[100., 100.],
        [100., 100.]], requires_grad=True), Parameter containing:
tensor([20., 20.], requires_grad=True)], 'lr': 0.01, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False}]}

2. 测试optimizer的param_groups包含哪些参数

print(optimizer.param_groups)

得到:

[{'params': [Parameter containing:
tensor([[100., 100.],
        [100., 100.]], requires_grad=True), Parameter containing:
tensor([20., 20.], requires_grad=True)], 'lr': 0.01, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False}]

其中2x2的矩阵是net的权重矩阵,1x2为偏置矩阵,其余为优化器的其它参数,所以说param_groups保存了优化器的全部数据,这个下面的state_dict()不同。

3. optimizer的状态 state_dict()

参考下面源码中对state_dict()的定义,可以看出state_dict()包含优化器状态state和参数param_groups两个参数

def state_dict(self):
    r"""Returns the state of the optimizer as a :class:`dict` """
    # Save ids instead of Tensors
    def pack_group(group):
        # 对"params"和其它的键采用不同规则
        packed = {k: v for k, v in group.items() if k != 'params'}
        # 这里并没有保存参数的值,而是保存参数的id
        packed['params'] = [id(p) for p in group['params']]
        return packed
    # 对self.param_groups进行遍历
    param_groups = [pack_group(g) for g in self.param_groups]
    # Remap state to use ids as keys
    packed_state = {(id(k) if isinstance(k, torch.Tensor) else k): v
                    for k, v in self.state.items()}
    # 返回状态和参数组,其中参数组才是优化器的参数
    return {
        'state': packed_state,
        'param_groups': param_groups,
    }

打印优化器参数:

print(optimizer.state_dict()["param_groups"])

可以到优化器的完整参数如下:

[{'lr': 0.01, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 
'nesterov': False, 'params': [2149749904224, 2149749906312]}]

打印优化器完整状态(状态+参数):

print(optimizer.state_dict())

可以到优化器的状态如下:

{'state': {}, 'param_groups': [{'lr': 0.01, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [2423968124216, 2423968124360]}]}

保存优化器的完整状态:

optimizer_old = optim.SGD(net.parameters(), lr=100) 
torch.save(optimizer_old.state_dict(), "optim_old.npy")

4. optimizer的load_state_dict()

恢复优化器的完整状态:

optimizer_new = optim.SGD(net.parameters(), lr=0.01)
old_state = torch.load("optim_old.npy")
# 将之前定义的优化器参数给新的优化器
optimizer_new.load_state_dict(old_state)
print(optimizer_new.state_dict()["param_groups"])

5. optimizer的梯度清空zero_grad()

optimizer.zero_grad()源码定义如下:

def zero_grad(self):
    r"""Clears the gradients of all optimized :class:`torch.Tensor` s."""
    # 获取每一组参数
    for group in self.param_groups:
        # 遍历当前参数组所有的params
        for p in group['params']:
            if p.grad is not None:
                p.grad.detach_()
                p.grad.zero_()

这个遍历过程就是获取optimizer的param_groups属性的字典,之中的[“params”],之中的所有参数,通过遍历设定每个参数的梯度值为0。

6. optimizer的单步更新step()

直接看源码:

def step(self, closure):
    r"""Performs a single optimization step (parameter update).
    Arguments:
        closure (callable): A closure that reevaluates the model and
            returns the loss. Optional for most optimizers.
    """
    raise NotImplementedError

优化器的step()函数负责更新参数值,但是其具体实现对于不同的优化算法是不同的,所以optimizer类只是定义了这种行为,但是并没有给出具体实现。

【其他参考资料】

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值