Torch 梯度下降法 —— 非线性优化

环境及基本函数如下:

import logging
from typing import Callable, Optional

import torch
from tqdm import tqdm

logging.basicConfig(format='%(message)s', level=logging.INFO)
LOGGER = logging.getLogger(__name__)

torch.autograd.set_detect_anomaly(True)

非线性优化的函数分为三部分:

  • __new__:与 __init__ 的作用相似,最大的不同是有返回值,可以使这个类能像函数一样被调用,返回值是 best_variant (最优自变量)、min_loss (最小损失值)、log (损失值的历史记录)
  • main:输入是 patience (耐心值)、max_iter (最大迭代次数),这两个参数控制了迭代次数和终止条件
  1. 指定次数:当 patience = None 时,在迭代指定次数 (max_iter) 后退出
  2. 贪心模式:当 max_iter = None 时,会在 min_loss 不再变化并连续 50 (patience) 次后才退出
  3. 懒惰模式:如果 max_iter 和 patience 都不为 None,则限定最大迭代次数,并在 min_loss 不再变化并连续 50 (patience) 次后,提前退出
  • update:使用梯度下降法对变量进行更新,并记录更优的变量,同时把损失值记录进日志中

eval_fcn 的作用:log (损失值的历史记录) 默认是使用 loss_fcn 计算得到的,如果设置了 eval_fcn,则会存储 eval_fcn 计算的损失值,例如:使用 SIoU 损失对锚框进行回归,但是为了验证 IoU 损失是否随着 SIoU 损失的下降而下降,设置 eval_fcn 为 IoU 损失

class minimize:
    ''' variant: 作为变量的 tensor
        loss_fcn: 以 variant 为输入, loss 为输出的函数
        lr: 学习率
        patience: 允许 loss 无进展的次数
        eval_fcn: 需要记录的损失函数
        max_iter: 最大迭代次数
        prefix: 进度条前缀
        title: 输出标题
        return: 最优变量的 tensor, 最小 loss 值, loss 日志'''

    def __new__(cls,
                variant: torch.tensor,
                loss_fcn: Callable,
                lr: float,
                eval_fcn=None,
                patience: Optional[int] = 50,
                max_iter: Optional[int] = None,
                prefix: str = 'Minimize',
                title: bool = True,
                leave: bool = True):
        assert patience or max_iter
        # 初始化变量
        variant.requires_grad = True
        cls.variant = variant
        cls.optimizer = torch.optim.Adam([variant], lr=lr)
        # 记录最优变量
        cls.min_loss, cls.best_variant, cls.log = float('inf'), None, []
        if title: LOGGER.info(('%10s' * 3) % ('', 'cur_loss', 'min_loss'))
        # 设置类变量
        cls.prefix = prefix
        cls.leave = leave
        instance = object.__new__(cls)
        instance.loss_fcn, instance.eval_fcn = loss_fcn, eval_fcn
        instance.main(patience, max_iter)
        return instance.best_variant, instance.min_loss, instance.log

    def main(self, patience, max_iter):
        # 初始化迭代参数
        pbar = tqdm(range(max_iter if max_iter else patience), leave=self.leave)
        angry = 0 if patience else None
        if not max_iter:
            # 贪心模式
            while angry != patience:
                is_better = self.update(pbar)
                angry = 0 if is_better else angry + 1
                pbar.reset() if is_better else pbar.update()
        else:
            # 指定次数
            for _ in pbar:
                is_better = self.update(pbar)
                # 懒惰模式
                if patience:
                    angry = 0 if is_better else angry + 1
                    if angry == patience: break
        pbar.close()

    def update(self, pbar):
        is_better = False
        # 计算损失值, 记入日志
        loss = self.loss_fcn(self.variant)
        loss_value = loss.item() if not self.eval_fcn else self.eval_fcn(self.variant).item()
        self.log.append(loss_value)
        # 保存更优的变量
        if loss_value < self.min_loss:
            self.min_loss, self.best_variant = loss_value, self.variant.clone().detach()
            is_better = True
        # 反向传播梯度, 更新变量
        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()
        pbar.set_description(('%10s' + '%10.4g' * 2) % (self.prefix, loss_value, self.min_loss))
        return is_better

求解示例:

  • 原函数:y(x)=x+0.3x^2+0.5x^3+4\sin{x}+noise
  • 近似函数:y(x)=ax+bx^2+cx^3+de^{-x}
if __name__ == '__main__':
    import matplotlib.pyplot as plt

    x = torch.linspace(-3, 3, 50)
    # 原函数: x + 0.3 x^2 - 0.5 x^3 + 4 sin(x) + 噪声
    y = x + 0.3 * x ** 2 - 0.5 * x ** 3 + 4 * torch.sin(x) + 5 * (torch.rand(len(x)) - 0.5)
    # 绘制原函数计算得到的散点
    plt.scatter(x, y, c='deepskyblue', label='true')


    def cal_y(variant, x):
        x = torch.stack([x, x ** 2, x ** 3, torch.exp(-x)], dim=1)
        # 矩阵乘法: [bs, 4] × [4, ] -> [bs, ]
        y = x @ variant
        return y


    def loss(variant):
        pred_y = cal_y(variant, x)
        # 均方差 (MSE) 损失
        return ((y - pred_y) ** 2).mean()


    # 拟合函数: a x + b x^2 + c x^3 + d e^x
    # 优化目标: [a, b, c, d]
    best_var, *_ = minimize(torch.ones(4), loss_fcn=loss, lr=1e-1, patience=50, max_iter=2000)
    # 绘制预测值
    print(best_var)
    plt.plot(x, cal_y(best_var, x), c='orange', label='pred')

    plt.legend()
    plt.show()

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

荷碧TongZJ

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值