The warmup learning-rate strategy uses a relatively small learning rate at the start of training and increases it linearly up to the initially configured value. Roughly, the learning rate rises from 0 to 0.01, after which training continues with the normal learning-rate schedule.
import torch
from torch.optim.lr_scheduler import _LRScheduler


class WarmUpLR(_LRScheduler):
    """Warmup training learning rate scheduler.

    Args:
        optimizer: optimizer (e.g. SGD)
        total_iters: total iterations of the warmup phase
    """

    def __init__(self, optimizer, total_iters, last_epoch=-1):
        self.total_iters = total_iters
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Use the first m batches and set the learning
        rate to base_lr * m / total_iters.
        """
        # last_epoch counts scheduler steps (here: batches); the small
        # epsilon avoids division by zero when total_iters is 0
        return [base_lr * self.last_epoch / (self.total_iters + 1e-8)
                for base_lr in self.base_lrs]
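
To use the scheduler, call step() once per batch during the warmup phase, then hand control over to the regular schedule. A minimal usage sketch, assuming a placeholder model, SGD with the 0.01 target learning rate from above, and illustrative values for iters_per_epoch and warmup_epochs:

import torch.nn as nn

model = nn.Linear(10, 2)  # placeholder model for illustration
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

iters_per_epoch = 100     # assumed number of batches per epoch
warmup_epochs = 1         # warm up during the first epoch
warmup_scheduler = WarmUpLR(optimizer,
                            total_iters=iters_per_epoch * warmup_epochs)

for epoch in range(5):
    for batch in range(iters_per_epoch):
        # ... forward pass, loss.backward(), optimizer.step() ...
        if epoch < warmup_epochs:
            warmup_scheduler.step()  # lr grows linearly toward 0.01
        # after warmup, step the regular scheduler (e.g. MultiStepLR) instead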