Warmup + a custom learning-rate schedule (e.g. cosine)

CustomScheduler ramps the learning rate linearly up to base_lr over the first warmup_steps iterations, then decays it toward eta_min along the curve returned by func(); PolyScheduler and CosineScheduler override func() to get polynomial and cosine decay.
from math import cos, pi

from torch.optim.lr_scheduler import _LRScheduler


class CustomScheduler(_LRScheduler):
    """Linear warmup to base_lr, then a decay curve defined by func()."""

    def __init__(self, optimizer, base_lr, max_steps, warmup_steps, eta_min=0, last_epoch=-1):
        self.base_lr = base_lr
        self.warmup_lr_init = 0.0001  # lr reported before the first step
        self.max_steps: int = max_steps
        self.warmup_steps: int = warmup_steps
        self.power = 2  # exponent used by PolyScheduler
        self.eta_min = eta_min
        # Initialize with last_epoch=-1 so the base class records the initial
        # lrs, then rewind to the requested step (e.g. when resuming training).
        super().__init__(optimizer, -1, False)
        self.last_epoch = last_epoch

    def get_warmup_lr(self):
        # Ramp the lr linearly from 0 to base_lr over warmup_steps.
        alpha = float(self.last_epoch) / float(self.warmup_steps)
        return [self.base_lr * alpha for _ in self.optimizer.param_groups]

    def get_lr(self):
        if self.last_epoch == -1:
            return [self.warmup_lr_init for _ in self.optimizer.param_groups]
        if self.last_epoch < self.warmup_steps:
            return self.get_warmup_lr()
        # Decay phase: func() maps post-warmup progress to a factor in [0, 1],
        # which interpolates the lr between base_lr and eta_min.
        alpha = self.func()
        return [(self.base_lr - self.eta_min) * alpha + self.eta_min
                for _ in self.optimizer.param_groups]

    def func(self):
        # Linear decay from 1 to 0 after warmup; subclasses override this.
        return (
            1
            - float(self.last_epoch - self.warmup_steps)
            / float(self.max_steps - self.warmup_steps)
        )

class PolyScheduler(CustomScheduler):
    """Polynomial decay: alpha = (1 - progress) ** power."""

    # __init__ is inherited unchanged from CustomScheduler.
    def func(self):
        return pow(
            1
            - float(self.last_epoch - self.warmup_steps)
            / float(self.max_steps - self.warmup_steps),
            self.power,
        )

class CosineScheduler(CustomScheduler):
    """Quarter-period cosine decay: alpha = cos(pi/2 * progress), falling
    from 1 at the end of warmup to 0 at max_steps. (The common
    cosine-annealing factor is (1 + cos(pi * progress)) / 2; this variant
    uses a quarter period instead.)"""

    def func(self):
        return cos(
            pi / 2
            * float(self.last_epoch - self.warmup_steps)
            / float(self.max_steps - self.warmup_steps)
        )
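
A minimal usage sketch, assuming the scheduler is stepped once per training iteration; the model, optimizer, and the values of base_lr, max_steps, and warmup_steps below are illustrative, not part of the scheduler itself:

import torch

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Hypothetical settings: 1000 total iterations, the first 100 as linear warmup.
scheduler = CosineScheduler(optimizer, base_lr=0.1, max_steps=1000,
                            warmup_steps=100, eta_min=1e-5)

for step in range(1000):
    ...  # forward pass, loss.backward()
    optimizer.step()
    scheduler.step()  # lr ramps up for 100 steps, then follows the cosine curve

Because eta_min acts as a floor in get_lr(), the learning rate never decays below it, which is useful when you want the model to keep taking small update steps at the end of training.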