代码 — 构造优化调度器,根据当前的 epoch 调整训练的 learning rate(线性预热 → 保持 → 指数衰减)。
class ScheduledOptim:
    """A simple wrapper around an optimizer that schedules the learning rate
    per epoch: linear warm-up from ``lr_min`` to ``lr_max``, an optional
    plateau at ``lr_max``, then exponential decay back toward ``lr_min``.

    Args:
        optimizer: wrapped optimizer exposing ``step``, ``zero_grad`` and
            ``param_groups`` (e.g. a ``torch.optim`` optimizer).
        n_warmup_epochs: number of epochs of linear warm-up.
        sustain_epochs: epochs to hold the rate at ``lr_max`` after warm-up.
        lr_max: peak learning rate.
        lr_min: starting / floor learning rate.
        lr_exp_decay: per-epoch multiplicative decay factor after the plateau.
    """

    def __init__(self,
                 optimizer,
                 n_warmup_epochs=6,
                 sustain_epochs=0,
                 lr_max=1e-3,
                 lr_min=1e-5,
                 lr_exp_decay=0.4):
        self._optimizer = optimizer
        self.n_warmup_epochs = n_warmup_epochs
        self.sustain_epochs = sustain_epochs
        self.init_lr = lr_min  # kept for backward compatibility; equals lr_min
        self.lr_min = lr_min
        self.lr_max = lr_max
        self.lr_exp_decay = lr_exp_decay

    def step_and_update_lr(self, epoch):
        """Set the scheduled learning rate for *epoch*, then step the optimizer."""
        self._update_learning_rate(epoch)
        self._optimizer.step()

    def zero_grad(self):
        """Zero out the gradients via the inner optimizer."""
        self._optimizer.zero_grad()

    def _compute_lr(self, epoch):
        """Return the scheduled learning rate for *epoch* (pure, no side effects)."""
        if epoch < self.n_warmup_epochs:
            # Linear ramp from lr_min toward lr_max over the warm-up epochs.
            return (self.lr_max - self.lr_min) / self.n_warmup_epochs * epoch + self.init_lr
        if epoch < self.n_warmup_epochs + self.sustain_epochs:
            # Hold at the peak rate during the sustain plateau.
            return self.lr_max
        # Exponential decay toward the lr_min floor after the plateau.
        return ((self.lr_max - self.lr_min)
                * self.lr_exp_decay ** (epoch - self.n_warmup_epochs - self.sustain_epochs)
                + self.lr_min)

    def _update_learning_rate(self, epoch):
        """Apply the scheduled rate to every param group and return it.

        Bug fix: the original returned ``None``, so ``draw`` collected and
        plotted a list of Nones; it now returns the computed rate.
        """
        lr = self._compute_lr(epoch)
        for param_group in self._optimizer.param_groups:
            param_group['lr'] = lr
        return lr

    def draw(self, epochs):
        """Plot the learning-rate curve over ``epochs`` epochs.

        Uses the pure ``_compute_lr`` so plotting no longer mutates the
        wrapped optimizer's param groups (the original had that side effect).
        """
        # Local import so matplotlib is only required when plotting;
        # ``pyplot`` is the supported interface (``pylab`` is deprecated).
        import matplotlib.pyplot as plt
        lrs = [self._compute_lr(i) for i in range(epochs)]
        plt.plot(lrs)
        plt.show()
调用示例:
# Usage: wrap a plain Adam optimizer with the scheduler and drive it
# once per epoch so the learning rate follows the warm-up/decay schedule.
optimizer = Adam(model.parameters(), lr=5e-4, eps=1e-16, betas=(0.9, 0.999))
optim_schedule = ScheduledOptim(optimizer)

for epoch in range(epochs):
    optim_schedule.zero_grad()                # clear gradients via the wrapper
    loss.backward()                           # backpropagate
    optim_schedule.step_and_update_lr(epoch)  # set this epoch's LR, then step