pytorch学习之:使用 warm up 方法构造优化调度器优化神经网络参数

文章目录

代码

  • 构造优化调度器,根据当前的 epoch 调整训练的 learning rate

class ScheduledOptim:
    '''A simple wrapper class for learning rate scheduling'''

    def __init__(self
                 , optimizer
                 , n_warmup_epochs=6
                 , sustain_epochs=0
                 , lr_max=1e-3
                 , lr_min=1e-5
                 , lr_exp_decay=0.4):

        self._optimizer = optimizer
        self.n_warmup_epochs = n_warmup_epochs
        self.sustain_epochs = sustain_epochs
        self.init_lr = lr_min
        self.lr_min = lr_min
        self.lr_max = lr_max
        self.lr_exp_decay = lr_exp_decay

    def step_and_update_lr(self, epoch):
        "Step with the inner optimizer"
        self._update_learning_rate(epoch)
        self._optimizer.step()

    def zero_grad(self):
        "Zero out the gradients by the inner optimizer"
        self._optimizer.zero_grad()

    def _update_learning_rate(self, epoch):
        ''' Learning rate scheduling per epoch '''

        if epoch < self.n_warmup_epochs:
            lr = (self.lr_max - self.lr_min) / self.n_warmup_epochs * epoch + self.init_lr
        elif epoch < self.n_warmup_epochs + self.sustain_epochs:
            lr = self.lr_max
        else:
            lr = (self.lr_max - self.lr_min) \
                 * self.lr_exp_decay ** (epoch - self.n_warmup_epochs - self.sustain_epochs) \
                 + self.lr_min
        # return lr
        for param_group in self._optimizer.param_groups:
            param_group['lr'] = lr

    def draw(self, epochs):
    	"""
    	画出优化器的变化趋势 plot
    	"""
        lrs = []
        for i in range(epochs):
            lr = self._update_learning_rate(i)
            lrs.append(lr)
        import matplotlib.pylab as plt
        plt.plot(lrs)
        plt.show()

调用

optimizer = Adam(model.parameters()
                            , lr=5e-4
                            , eps=1e-16
                            , betas=(0.9, 0.999)
                            )
optim_schedule = ScheduledOptim(optimizer)

for epoch in range(epochs):
	# 其他代码....
	
	# 通过 schedule 根据不同的 epoch 进行 lr 更新
	optim_schedule.zero_grad()
	loss.backward()
	optim_schedule.step_and_update_lr(epoch)
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

暖仔会飞

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值