余弦退火算法遇到的问题

参考大佬博客

http://t.csdn.cn/TddFi

公式

gif.latex?%5Calpha%20_%7Bt%7D%20%3D%200.5%20*%20%5Calpha%20_%7B0%7D%20%5Ccdot%20%281+cos%28%5Cfrac%7Bt%5Ccdot%20%5Cpi%20%7D%7BT%7D%29%29

 其中,gif.latex?%5Calpha%20_%7B0%7D 代表初始学习率,gif.latex?t 是指当前是第几个 step,gif.latex?T 是指多少个 step 之后学习率衰减为0

这里仿照大佬的文章向自己的模型中添加学习率余弦退火衰减的时候,因为最后我需要调用.save方法,因此在重写学习率变化类的时候,也需要重写类中的get_config方法!!这里很重要!!在重写方法的时候方法的返回值也有格式要求,我把具体问题放在最后讲解。

 因为,现在处于测试阶段,所以我想尝试不同的学习率优化方法对训练的印象,这次选择了余弦退火算法,但是遇到了一些问题,导致最后在保存模型的时候发生了错误。

这里直接附上重写的学习率的类,请注意此类中的最后一个方法!

import tensorflow as tf
import math
tf.config.experimental_run_functions_eagerly(True)
# 继承自定义学习率的类
class CosineWarmupDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
    '''
    initial_lr: 初始的学习率
    min_lr: 学习率的最小值
    max_lr: 学习率的最大值
    warmup_step: 线性上升部分需要的step
    total_step: 第一个余弦退火周期需要对总step
    multi: 下个周期相比于上个周期调整的倍率
    print_step: 多少个step并打印一次学习率
    '''
    # 初始化
    def __init__(self, initial_lr, min_lr, warmup_step, total_step, multi, print_step):
        # 继承父类的初始化方法
        super(CosineWarmupDecay, self).__init__()
        
        # 属性分配
        self.initial_lr = tf.cast(initial_lr, dtype=tf.float32)
        self.min_lr = tf.cast(min_lr, dtype=tf.float32)
        self.warmup_step = warmup_step  # 初始为第一个周期的线性段的step
        self.total_step = total_step    # 初始为第一个周期的总step
        self.multi = multi
        self.print_step = print_step
        
        # 保存每一个step的学习率
        self.learning_rate_list = []
        # 当前步长
        self.step = 0
        
        
    # 前向传播, 训练时传入当前step,但是上面已经定义了一个,这个step用不上
    def __call__(self, step):
        
        # 如果当前step达到了当前周期末端就调整
        if  self.step>=self.total_step:
            
            # 乘上倍率因子后会有小数,这里要注意
            # 调整一个周期中线性部分的step长度
            self.warmup_step = self.warmup_step * (1 + self.multi)
            # 调整一个周期的总step长度
            self.total_step = self.total_step * (1 + self.multi)
            
            # 重置step,从线性部分重新开始
            self.step = 0
            
        # 余弦部分的计算公式
        decayed_learning_rate = self.min_lr + 0.5 * (self.initial_lr - self.min_lr) *       \
                                (1 + tf.math.cos(math.pi * (self.step-self.warmup_step) /        \
                                  (self.total_step-self.warmup_step)))
        
        # 计算线性上升部分的增长系数k
        k = (self.initial_lr - self.min_lr) / self.warmup_step 
        # 线性增长线段 y=kx+b
        warmup = k * self.step + self.min_lr
        
        # 以学习率峰值点横坐标为界,左侧是线性上升,右侧是余弦下降
        decayed_learning_rate = tf.where(self.step<self.warmup_step, warmup, decayed_learning_rate)
        
        
        # 每个epoch打印一次学习率
        if step % self.print_step == 0:
            # 打印当前step的学习率
            print('learning_rate has changed to: ', decayed_learning_rate.numpy().item())
        
        # 每个step保存一次学习率
        self.learning_rate_list.append(decayed_learning_rate.numpy().item())
 
        # 计算完当前学习率后step加一用于下一次
        self.step = self.step + 1
        
        # 返回调整后的学习率
        return decayed_learning_rate
    #若最后不调用.save方法保存模型可以不用重写
    def get_config(self):
        config = {
        '''
        如果这里是一个数的话最后会报错:TypeError: ('Not JSON Serializable:', )所以我改成了一个包含学习率的列表
        '''
        'learning_rate_list':self.learning_rate_list,
        }       
        return config

还有如果没有重写get_config则会报错:

NotImplementedError: Learning rate schedule must override get_config

方便理解这里附上上述重写学习率类的父类LearningRateSchedule

@keras_export("keras.optimizers.schedules.LearningRateSchedule")
class LearningRateSchedule(object):
  """A serializable learning rate decay schedule.

  `LearningRateSchedule`s can be passed in as the learning rate of optimizers in
  `tf.keras.optimizers`. They can be serialized and deserialized using
  `tf.keras.optimizers.schedules.serialize` and
  `tf.keras.optimizers.schedules.deserialize`.
  """

  @abc.abstractmethod
  def __call__(self, step):
    raise NotImplementedError("Learning rate schedule must override __call__")

  @abc.abstractmethod
  def get_config(self):
    raise NotImplementedError("Learning rate schedule must override get_config")

  @classmethod
  def from_config(cls, config):
    """Instantiates a `LearningRateSchedule` from its config.

    Args:
        config: Output of `get_config()`.

    Returns:
        A `LearningRateSchedule` instance.
    """
    return cls(**config)


@keras_export("keras.optimizers.schedules.ExponentialDecay")

目前遇到的问题都已解决,后续再有问题会继续添加 

  • 2
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值