import tensorflow as tf
import math
import matplotlib.pyplot as plt
class CosineWarmup(tf.keras.optimizers.schedules.LearningRateSchedule):
def __init__(self, warmup_slope, warmup_steps, cosine_steps):
    """Store the schedule hyper-parameters as float32 tensors.

    Args:
        warmup_slope: Factor that multiplies the in-cycle step while the
            schedule is in its warmup phase (the first branch of
            ``__call__``).
        warmup_steps: Length of the warmup phase within one cycle; the
            warmup branch is taken while ``step % steps < warmup_steps``.
        cosine_steps: Length of the remainder of each cycle. Presumably
            this is the cosine-decay phase, but that branch is not visible
            here. TODO confirm.
    """
    super().__init__()

    def _as_float32(value):
        # Every hyper-parameter is converted with the same float32 cast.
        return tf.cast(value, dtype=tf.float32)

    self.warmup_slope = _as_float32(warmup_slope)
    self.warmup_steps = _as_float32(warmup_steps)
    self.cosine_steps = _as_float32(cosine_steps)
    # Period of one full warmup + cosine cycle; ``__call__`` wraps the
    # incoming step around this value with ``%``.
    # NOTE(review): a zero total period would make that modulo ill-defined.
    self.steps = self.warmup_steps + self.cosine_steps
def __call__(self, step):
if step%self.steps < self.warmup_steps:
return self.warmup_slope * (step%s
余弦退火学习率衰减策略
最新推荐文章于 2024-10-07 06:30:00 发布