Lookahead 优化算法是Adam的作者继Adam之后的又一力作,论文可以参见https://arxiv.org/abs/1907.08610
这篇博客先不讲述Lookahead具体原理,先介绍如何将Lookahead集成到现有的代码中。
本人在三个项目中(涉及风格转换、物体识别)使用该优化器,最大的感受就是使用该优化器十分有利于模型收敛,原本不收敛或者收敛过慢的模型在使用lookahead后可以看到明显的收敛情况,并且最终的效果能够满足最初设计的要求。
总所周知,Adam因为其具有较好的适应性,被广泛用于各类模型的优化;其参数简单,调参方便的特点一直为大家所喜爱,尤其对于初学者较为友好。Lookahead 也继承了Adam的优点。lookahead的Pytorch版本代码如下所示:后续会针对代码进行原理讲解,该代码在Github上可以找到。
from collections import defaultdict
from torch.optim import Optimizer
import torch
class Lookahead(Optimizer):
def __init__(self, optimizer, k=5, alpha=0.5):
self.optimizer = optimizer
self.k = k
self.alpha = alpha
self.param_groups = self.optimizer.param_groups
self.state = defaultdict(dict)
self.fast_state = self.optimizer.state
for group in self.param_groups:
group["counter"] = 0
def update(self, group):
for fast in group["params"]:
param_state = self.state[fast]
if "slow_param" not in param_state:
param_state["slow_param"] = torch.zeros_like(fast.data)
param_state["slow_param"].copy_(fast.data)
slow = param_state["slow_param"]
slow += (fast.data - slow) * self.alpha
fast.data.copy_(slow)
def update_lookahead(self):
for group in self.param_groups:
self.update(group)
def step(self, closure=None):
loss = self.optimizer.step(closure)
for group in self.param_groups:
if group["counter"] == 0:
self.update(group)
group["counter"] += 1
if group["counter"] >= self.k:
group["counter"] = 0
return loss
def state_dict(self):
fast_state_dict = self.optimizer.state_dict()
slow_state = {
(id(k) if isinstance(k, torch.Tensor) else k): v
for k, v in self.state.items()
}
fast_state = fast_state_dict["state"]
param_groups = fast_state_dict["param_groups"]
return {
"fast_state": fast_state,
"slow_state": slow_state,
"param_groups": param_groups,
}
def load_state_dict(self, state_dict):
slow_state_dict = {
"state": state_dict["slow_state"],
"param_groups": state_dict["param_groups"],
}
fast_state_dict = {
"state": state_dict["fast_state"],
"param_groups": state_dict["param_groups"],
}
super(Lookahead, self).load_state_dict(slow_state_dict)
self.optimizer.load_state_dict(fast_state_dict)
self.fast_state = self.optimizer.state
def add_param_group(self, param_group):
param_group["counter"] = 0
self.optimizer.add_param_group(param_group)
将lookahead集成在现有代码中如下操作即可:
base_optimizer = Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999))
opt = Lookahead(base_optimizer, k=5, alpha=0.5)
此时直接将opt作为正常的优化器使用即可,就像直接使用Adam一样的步骤使用opt