基于keras的transformer learning rate schedule

大家都知道,keras的learning rate schedule是基于epoch的,对于基于steps的learning rate schedule来说,比较难实现,网上都是实现了的tf2的版本的,对于tf1版本的几乎没有,因此我写了一个基于keras2.3.1以及tf1.15的transformer learning rate schedule

        for p, g, m, v, vhat in zip(params, grads, ms, vs, vhats):
            m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
            v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(g)
            if self.amsgrad:
                vhat_t = K.maximum(vhat, v_t)
                p_t = p - lr_t * m_t / (K.sqrt(vhat_t) + self.epsilon)
                self.updates.append(K.update(vhat, vhat_t))
            else:
                p_t = p - lr_t * m_t / (K.sqrt(v_t) + self.epsilon)

            self.updates.append(K.update(m, m_t))
            self.updates.append(K.update(v, v_t))
            new_p = p_t

            # Apply constraints.
            if getattr(p, 'constraint', None) is not None:
                new_p = p.constraint(new_p)

            self.updates.append(K.update(p, new_p))

在keras.optimizer.Adam中,我们可以看到,p_t = p - lr_t * m_t / (K.sqrt(v_t) + self.epsilon),因此我们要做的就是改变lr_tlr_t是输入Adamlearning_rate,没办法轻易改变,因此我的想法是将他固定为1,再在每次更新的时候,重新乘以一个新的lr_multiplier,这个lr_multiplier即为transformer learning rate。

		@K.symbolic
        def get_updates(self, loss, params):
            lr_multiplier = transformer_schedule(self.iterations,
                                                      self.start_step,
                                                      self.warmup_steps,
                                                      self.d_model)

            old_update = K.update

            def new_update(x, new_x):
                if is_one_of(x, params):
                    new_x = x + (new_x - x) * lr_multiplier
                return old_update(x, new_x)

            K.update = new_update
            updates = super(NewOptimizer, self).get_updates(loss, params)
            K.update = old_update

            return updates

我们的主要做法是通过设定一个新的new_update函数,将p_t = p - lr_t * m_t / (K.sqrt(v_t) + self.epsilon)改写成p + (p_t-p)*lr_multiplier = p - lr_t * lr_multiplier * m_t / (K.sqrt(v_t) + self.epsilon)。如此一来,keras中的transformer learning rate schedule就成功实现了。

完整版本的keras transformer learning rate schedule已经开源在了keras-transformer-schedual,欢迎大家使用

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值