PyTorch 源码分析:Optimizer类

PyTorch对Optimizer类的实现大部分都在Python上,只有计算用到了C++的部分,所以还是可以继续分析的。

总览

Optimizer类是所有具体优化器类的一个基类。下面一幅图表示一下。

这里我以SGD类为例自下而上地介绍一下。

Optimizer类中重要的成员变量只有两个,self.param_groups和self.state。

self.param_groups用于存储模型参数和优化器本身的一些参数(如学习率等)。

self.state则用于存储更新过程中模型参数对应的各种临时状态,如MSGD中每个参数需要对应一个动量。而每个参数可能不止需要对应一个临时状态。因此self.state是一个键值对类型为parameter:dict的有序字典。

Optimizer类中重要的方法只有一个 add_param_group,它是用来初始化self.param_groups的。

而self.state的初始化需要在某个具体的优化器类中进行。

self.param_groups如何初始化?

self.param_groups在optimizer类的__init__方法中初始化完成。

这里可以先看一下SGD类的初始化方法,它将lr,momentum等优化器参数打包成字典defaults,然后和模型参数params一起传入optimizer类的初始化方法中。

class SGD(Optimizer):
    def __init__(self, params, lr=required, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        #打包优化器参数
        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super(SGD, self).__init__(params, defaults)

而在Optimizer类的初始化方法中,对于defaults,它只是将defaults存储起来。对于params则是先转换成列表形式,之后转换成一个由列表封装的字典。然后对这个字典执行self.add_param_group。

至此我们还是没有看到self.param_groups到底是怎么初始化的,所以需要继续看self.add_param_group这个方法。

注意区分这里的self.param_groups和param_groups。

class Optimizer(object):
    def __init__(self, params, defaults):
        torch._C._log_api_usage_once("python.optimizer")
        self.defaults = defaults

        self._hook_for_profile()

        if isinstance(params, torch.Tensor):
            raise TypeError("params argument given to the optimizer should be "
                            "an iterable of Tensors or dicts, but got " +
                            torch.typename(params))
        #self.state初始化
        self.state = defaultdict(dict)
        self.param_groups = []
        param_groups = list(params)
        if len(param_groups) == 0:
            raise ValueError("optimizer got an empty parameter list")
        #一般情况下param_groups[0]是一个parameters类
        #这里其实是在判断param_groups之前有没有被封装过。
        if not isinstance(param_groups[0], dict):
            param_groups = [{'params': param_groups}]
        #虽然是遍历操作,但是其实并不是遍历所有参数。
        for param_group in param_groups:
            #这里的param_group等价于{'params': param_groups}
            self.add_param_group(param_group)

add_param_group源码:

def add_param_group(self, param_group):
    r"""Add a param group to the :class:`Optimizer` s `param_groups`.

    This can be useful when fine tuning a pre-trained network as frozen layers can be made
    trainable and added to the :class:`Optimizer` as training progresses.

    Args:
        param_group (dict): Specifies what Tensors should be optimized along with group
        specific optimization options.
    """
    #前面都不重要,都是一些边界条件的判断
    assert isinstance(param_group, dict), "param group must be a dict"
    params = param_group['params']
    if isinstance(params, torch.Tensor):
        param_group['params'] = [params]
    elif isinstance(params, set):
        raise TypeError('optimizer parameters need to be organized in ordered collections, but '
                        'the ordering of tensors in sets will change between runs. Please use a list instead.')
    else:
        param_group['params'] = list(params)

    for param in param_group['params']:
        if not isinstance(param, torch.Tensor):
            raise TypeError("optimizer can only optimize Tensors, "
                            "but one of the params is " + torch.typename(param))
        if not param.is_leaf:
            raise ValueError("can't optimize a non-leaf Tensor")
    #这里开始就是self.param_groups的初始化了
    #defaults在这里加入param_group
    for name, default in self.defaults.items():
        if default is required and name not in param_group:
            raise ValueError("parameter group didn't specify a value of required optimization parameter " +
                                name)
        else:
            param_group.setdefault(name, default)

    params = param_group['params']
    if len(params) != len(set(params)):
        warnings.warn("optimizer contains a parameter group with duplicate parameters; "
                        "in future, this will cause an error; "
                        "see github.com/pytorch/pytorch/issues/40967 for more information", stacklevel=3)

    param_set = set()
    for group in self.param_groups:
        param_set.update(set(group['params']))

    if not param_set.isdisjoint(set(param_group['params'])):
        raise ValueError("some parameters appear in more than one parameter group")
    #把param_group这个字典放入self.param_groups这个空列表里
    #这样初始化就完成了
    self.param_groups.append(param_group)

add_param_group中对param_group的操作其实很简单,就是将之前的优化器参数self.defaults也放入param_group里。然后再把param_group存到self.param_groups里。

接下来从一个实际的例子看看是不是这样:

import torch
X = torch.tensor([1.0],requires_grad = True)
Y = torch.tensor([2.0],requires_grad = True)
optimizer = torch.optim.SGD([X,Y],lr =0.001)
print(optimizer.param_groups)
"""
输出结果:
[
{
'params': [tensor([1.], requires_grad=True), tensor([2.], requires_grad=True)], 
'lr': 0.001, 
'momentum': 0, 
'dampening': 0, 
'weight_decay': 0, 
'nesterov': False
}
]
"""

self.state如何更新?

上面self.param_groups初始化过程介绍的差不多了,接下来考虑self.state的初始化和更新问题,因为之前说过self.state每一次迭代都会更新。而优化器的更新操作是放在step这个方法里的,但是optimizer基类并不会实现step这个方法,需要每一个子类自己去实现。所以我这里以SGD为例介绍一下优化器的更新流程。

传统的SGD肯定是不需要使用self.state的,PyTorch这里的SGD只有在带动量的情况下会需要使用self.state。动量的意思简单来说就是存储过去的梯度信息。这样相比于SGD只基于当前梯度进行更新,带动量的SGD可以基于当前+过去的梯度进行更新,收敛更快。

class SGD(Optimizer):
@torch.no_grad()
def step(self, closure=None):
    """Performs a single optimization step.

    Args:
        closure (callable, optional): A closure that reevaluates the model
            and returns the loss.
    """
    loss = None
    if closure is not None:
        with torch.enable_grad():
            loss = closure()

    for group in self.param_groups:
        #存储有梯度的参数
        params_with_grad = []
        #存储参数对应的梯度
        d_p_list = []
        #存储动量
        momentum_buffer_list = []
        #正则化系数
        weight_decay = group['weight_decay']
        #动量系数
        momentum = group['momentum']
        #忘了
        dampening = group['dampening']
        nesterov = group['nesterov']
        #学习率
        lr = group['lr']
        #从self.state中取出momentum_buffer
        #初始化momentum_buffer_list
        #注意此时的momentum_buffer只包含过去的梯度信息
        for p in group['params']:
            if p.grad is not None:
                params_with_grad.append(p)
                d_p_list.append(p.grad)
                #自动初始化为空字典
                state = self.state[p]
                if 'momentum_buffer' not in state:
                    momentum_buffer_list.append(None)
                else:
                    momentum_buffer_list.append(state['momentum_buffer'])
        #对参数更新
        #并且更新momentum_buffer
        #该函数执行完后momentum_buffer将包含过去+现在的梯度信息
        F.sgd(params_with_grad,
                d_p_list,
                momentum_buffer_list,
                weight_decay=weight_decay,
                momentum=momentum,
                lr=lr,
                dampening=dampening,
                nesterov=nesterov)
        #momentum_buffer_list是通过append复制操作得到state里的momentum_buffer的
        #所以虽然momentum_buffer_list已经更新了,但是state里的momentum_buffer还没更新
        #所以需要同步一下,便于下一次迭代继续从state里取momentum_buffer。
        # update momentum_buffers in state
        for p, momentum_buffer in zip(params_with_grad, momentum_buffer_list):
            #每个参数都对应一个state字典。
            #在这里,每个参数都对应一个动量
            state = self.state[p]
            state['momentum_buffer'] = momentum_buffer

    return loss

这里看到step方法被@torch.no_grad()装饰器修饰,是因为需要对叶子节点做inplace操作,我之前有这一部分的介绍,这里就不赘述了。

step方法对参数的更新主要分为三步:

第一步是从self.state中取出momentum_buffer转换成列表形式momentum_buffer_list。

第二步是对参数进行更新,同时momentum_buffer_list也会得到更新。

第三步是利用更新后的momentum_buffer_list对state中的momentum_buffer进行更新。

真正的更新操作都被放在了F.sgd里,这里我做了注释,大家有兴趣可以看一下。

def sgd(params: List[Tensor],
        d_p_list: List[Tensor],
        momentum_buffer_list: List[Optional[Tensor]],
        *,
        weight_decay: float,
        momentum: float,
        lr: float,
        dampening: float,
        nesterov: bool):
    r"""Functional API that performs SGD algorithm computation.

    See :class:`~torch.optim.SGD` for details.
    """
    #遍历所有参数
    for i, param in enumerate(params):
        #取出参数对应的梯度
        d_p = d_p_list[i]
        #梯度加上正则项
        if weight_decay != 0:
            d_p = d_p.add(param, alpha=weight_decay)
        #取出上一次迭代得到的动量(准备更新动量)
        if momentum != 0:
            buf = momentum_buffer_list[i]
            #第一次迭代时的动量初始化
            if buf is None:
                buf = torch.clone(d_p).detach()
                momentum_buffer_list[i] = buf
            #动量更新,都是inplace操作
            #mb*momentum_factor+(1-dampening)*grad
            else:
                buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
            # nesterov的更新方式
            if nesterov:
                d_p = d_p.add(buf, alpha=momentum)
            # 常规动量的更新方式
            else:
                d_p = buf
        #参数更新
        param.add_(d_p, alpha=-lr)

总结

这里只是对optimizer中和更新相关的源码进行了介绍,不过optimizer类中还有很多其他的方法,我目前都用不到,所以就暂时不看了。

  • 5
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值