pytorch pyro更高阶的优化器会使用更高阶的导数,比如二阶导数(Hessian矩阵)

在机器学习和深度学习中,优化器是用来更新模型参数以最小化损失函数的算法。通常,优化器会计算损失函数相对于参数的一阶导数(梯度),然后根据这些梯度来更新参数。但是,更高阶的优化器会使用更高阶的导数,比如二阶导数(Hessian矩阵),来指导参数的更新

 

关于使用更高阶导数的优化器基类的描述。在机器学习和深度学习中,优化器是用来更新模型参数以最小化损失函数的算法。通常,优化器会计算损失函数相对于参数的一阶导数(梯度),然后根据这些梯度来更新参数。但是,更高阶的优化器会使用更高阶的导数,比如二阶导数(Hessian矩阵),来指导参数的更新。

这段描述中的关键点包括:

  1. 使用torch.autograd.grad而不是torch.Tensor.backwardtorch.autograd.grad是PyTorch中的一个函数,它可以用来计算张量相对于其他张量的导数。这与torch.Tensor.backward不同,后者是自动求导机制的一部分,通常用于计算梯度。

  2. 不同的接口:由于高阶优化器需要计算更高阶的导数,它们需要一个不同的接口。在这个接口中,step方法接受一个损失张量作为输入,并在优化器内部触发一次或多次反向传播。

  3. 派生类必须实现step方法:这意味着任何从这个基类派生的优化器类都需要提供自己的step方法实现,以计算导数并就地更新参数。

  4. 示例代码:示例展示了如何使用这种优化器。首先,通过poutine.trace获取模型的跟踪,然后计算负对数概率之和作为损失。接着,从跟踪中提取参数,并调用优化器的step方法来更新这些参数。

简而言之,这段代码描述了一个用于高级优化的基类,它允许开发者实现使用更高阶导数的自定义优化器。这种类型的优化器可能在某些情况下比传统的一阶优化器更有效,尤其是在参数更新需要更精细控制的场景中。

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

from typing import Dict, List

import torch

from pyro.ops.newton import newton_step
from pyro.optim.optim import PyroOptim


class MultiOptimizer:
    """
    Base class of optimizers that make use of higher-order derivatives.

    Higher-order optimizers generally use :func:`torch.autograd.grad` rather
    than :meth:`torch.Tensor.backward`, and therefore require a different
    interface from usual Pyro and PyTorch optimizers. In this interface,
    the :meth:`step` method inputs a ``loss`` tensor to be differentiated,
    and backpropagation is triggered one or more times inside the optimizer.

    Derived classes must implement :meth:`step` to compute derivatives and
    update parameters in-place.

    Example::

        tr = poutine.trace(model).get_trace(*args, **kwargs)
        loss = -tr.log_prob_sum()
        params = {name: site['value'].unconstrained()
                  for name, site in tr.nodes.items()
                  if site['type'] == 'param'}
        optim.step(loss, params)
    """

    def step(self, loss: torch.Tensor, params: Dict) -> None:
        """
        Performs an in-place optimization step on parameters given a
        differentiable ``loss`` tensor.

        Note that this detaches the updated tensors.

        :param torch.Tensor loss: A differentiable tensor to be minimized.
            Some optimizers require this to be differentiable multiple times.
        :param dict params: A dictionary mapping param name to unconstrained
            value as stored in the param store.
        """
        updated_values = self.get_step(loss, params)
        for name, value in params.items():
            with torch.no_grad():
                # we need to detach because updated_value may depend on value
                value.copy_(updated_values[name].detach())

    def get_step(self, loss: torch.Tensor, params: Dict) -> Dict:
        """
        Computes an optimization step of parameters given a differentiable
        ``loss`` tensor, returning the updated values.

        Note that this preserves derivatives on the updated tensors.

        :param torch.Tensor loss: A differentiable tensor to be minimized.
            Some optimizers require this to be differentiable multiple times.
        :param dict params: A dictionary mapping param name to unconstrained
            value as stored in the param store.
        :return: A dictionary mapping param name to updated unconstrained
            value.
        :rtype: dict
        """
        raise NotImplementedError


class PyroMultiOptimizer(MultiOptimizer):
    """
    Facade to wrap :class:`~pyro.optim.optim.PyroOptim` objects
    in a :class:`MultiOptimizer` interface.
    """

    def __init__(self, optim: PyroOptim) -> None:
        if not isinstance(optim, PyroOptim):
            raise TypeError(
                "Expected a PyroOptim object but got a {}".format(type(optim))
            )
        self.optim = optim

    def step(self, loss: torch.Tensor, params: Dict) -> None:
        values = params.values()
        grads = torch.autograd.grad(loss, values, create_graph=True)  # type: ignore
        for x, g in zip(values, grads):
            x.grad = g
        self.optim(values)


class TorchMultiOptimizer(PyroMultiOptimizer):
    """
    Facade to wrap :class:`~torch.optim.Optimizer` objects
    in a :class:`MultiOptimizer` interface.
    """

    def __init__(self, optim_constructor: torch.optim.Optimizer, optim_args: Dict):
        optim = PyroOptim(optim_constructor, optim_args)
        super().__init__(optim)


class MixedMultiOptimizer(MultiOptimizer):
    """
    Container class to combine different :class:`MultiOptimizer` instances for
    different parameters.

    :param list parts: A list of ``(names, optim)`` pairs, where each
        ``names`` is a list of parameter names, and each ``optim`` is a
        :class:`MultiOptimizer` or :class:`~pyro.optim.optim.PyroOptim` object
        to be used for the named parameters. Together the ``names`` should
        partition up all desired parameters to optimize.
    :raises ValueError: if any name is optimized by multiple optimizers.
    """

    def __init__(self, parts: List) -> None:
        optim_dict: Dict = {}
        self.parts = []
        for names_part, optim in parts:
            if isinstance(optim, PyroOptim):
                optim = PyroMultiOptimizer(optim)
            for name in names_part:
                if name in optim_dict:
                    raise ValueError(
                        "Attempted to optimize parameter '{}' by two different optimizers: "
                        "{} vs {}".format(name, optim_dict[name], optim)
                    )
                optim_dict[name] = optim
            self.parts.append((names_part, optim))

    def step(self, loss: torch.Tensor, params: Dict):
        for names_part, optim in self.parts:
            optim.step(loss, {name: params[name] for name in names_part})

    def get_step(self, loss: torch.Tensor, params: Dict) -> Dict:
        updated_values = {}
        for names_part, optim in self.parts:
            updated_values.update(
                optim.get_step(loss, {name: params[name] for name in names_part})
            )
        return updated_values


class Newton(MultiOptimizer):
    """
    Implementation of :class:`MultiOptimizer` that performs a Newton update
    on batched low-dimensional variables, optionally regularizing via a
    per-parameter ``trust_radius``. See :func:`~pyro.ops.newton.newton_step`
    for details.

    The result of :meth:`get_step` will be differentiable, however the
    updated values from :meth:`step` will be detached.

    :param dict trust_radii: a dict mapping parameter name to radius of trust
        region. Missing names will use unregularized Newton update, equivalent
        to infinite trust radius.
    """

    def __init__(self, trust_radii: Dict = {}):
        self.trust_radii = trust_radii

    def get_step(self, loss: torch.Tensor, params: Dict):
        updated_values = {}
        for name, value in params.items():
            trust_radius = self.trust_radii.get(name)  # type: ignore
            updated_value, cov = newton_step(loss, value, trust_radius)
            updated_values[name] = updated_value
        return updated_values

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值